diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index de86f5b2859f..ab0576d372d6 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -1601,6 +1601,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, wake_userfault(vma->vm_userfaultfd_ctx.ctx, &range); } + /* Reset ptes for the whole vma range if wr-protected */ + if (userfaultfd_wp(vma)) + uffd_wp_range(mm, vma, start, vma_end - start, false); + new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS; prev = vma_merge(mm, prev, start, vma_end, new_flags, vma->anon_vma, vma->vm_file, vma->vm_pgoff, diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h index 732b522bacb7..e1b8a915e9e9 100644 --- a/include/linux/userfaultfd_k.h +++ b/include/linux/userfaultfd_k.h @@ -73,6 +73,8 @@ extern ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long dst_start, extern int mwriteprotect_range(struct mm_struct *dst_mm, unsigned long start, unsigned long len, bool enable_wp, atomic_t *mmap_changing); +extern void uffd_wp_range(struct mm_struct *dst_mm, struct vm_area_struct *vma, + unsigned long start, unsigned long len, bool enable_wp); /* mm helpers */ static inline bool is_mergeable_vm_userfaultfd_ctx(struct vm_area_struct *vma, diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c index 07d3befc80e4..7327b2573f7c 100644 --- a/mm/userfaultfd.c +++ b/mm/userfaultfd.c @@ -703,14 +703,29 @@ ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long start, mmap_changing, 0); } +void uffd_wp_range(struct mm_struct *dst_mm, struct vm_area_struct *dst_vma, + unsigned long start, unsigned long len, bool enable_wp) +{ + struct mmu_gather tlb; + pgprot_t newprot; + + if (enable_wp) + newprot = vm_get_page_prot(dst_vma->vm_flags & ~(VM_WRITE)); + else + newprot = vm_get_page_prot(dst_vma->vm_flags); + + tlb_gather_mmu(&tlb, dst_mm); + change_protection(&tlb, dst_vma, start, start + len, newprot, + enable_wp ? MM_CP_UFFD_WP : MM_CP_UFFD_WP_RESOLVE); + tlb_finish_mmu(&tlb); +} + int mwriteprotect_range(struct mm_struct *dst_mm, unsigned long start, unsigned long len, bool enable_wp, atomic_t *mmap_changing) { struct vm_area_struct *dst_vma; unsigned long page_mask; - struct mmu_gather tlb; - pgprot_t newprot; int err; /* @@ -750,15 +765,7 @@ int mwriteprotect_range(struct mm_struct *dst_mm, unsigned long start, goto out_unlock; } - if (enable_wp) - newprot = vm_get_page_prot(dst_vma->vm_flags & ~(VM_WRITE)); - else - newprot = vm_get_page_prot(dst_vma->vm_flags); - - tlb_gather_mmu(&tlb, dst_mm); - change_protection(&tlb, dst_vma, start, start + len, newprot, - enable_wp ? MM_CP_UFFD_WP : MM_CP_UFFD_WP_RESOLVE); - tlb_finish_mmu(&tlb); + uffd_wp_range(dst_mm, dst_vma, start, len, enable_wp); err = 0; out_unlock: