mempolicy: convert to vma iterator

Use the vma iterator so that the iterator can be invalidated or updated to
avoid each caller doing so.

Link: https://lkml.kernel.org/r/20230120162650.984577-21-Liam.Howlett@oracle.com
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
This commit is contained in:
Liam R. Howlett 2023-01-20 11:26:21 -05:00 committed by Andrew Morton
parent e552cdb853
commit f10c2abcda

View File

@ -787,24 +787,21 @@ static int vma_replace_policy(struct vm_area_struct *vma,
static int mbind_range(struct mm_struct *mm, unsigned long start, static int mbind_range(struct mm_struct *mm, unsigned long start,
unsigned long end, struct mempolicy *new_pol) unsigned long end, struct mempolicy *new_pol)
{ {
MA_STATE(mas, &mm->mm_mt, start, start); VMA_ITERATOR(vmi, mm, start);
struct vm_area_struct *prev; struct vm_area_struct *prev;
struct vm_area_struct *vma; struct vm_area_struct *vma;
int err = 0; int err = 0;
pgoff_t pgoff; pgoff_t pgoff;
prev = mas_prev(&mas, 0); prev = vma_prev(&vmi);
if (unlikely(!prev)) vma = vma_find(&vmi, end);
mas_set(&mas, start);
vma = mas_find(&mas, end - 1);
if (WARN_ON(!vma)) if (WARN_ON(!vma))
return 0; return 0;
if (start > vma->vm_start) if (start > vma->vm_start)
prev = vma; prev = vma;
for (; vma; vma = mas_next(&mas, end - 1)) { do {
unsigned long vmstart = max(start, vma->vm_start); unsigned long vmstart = max(start, vma->vm_start);
unsigned long vmend = min(end, vma->vm_end); unsigned long vmend = min(end, vma->vm_end);
@ -813,29 +810,23 @@ static int mbind_range(struct mm_struct *mm, unsigned long start,
pgoff = vma->vm_pgoff + pgoff = vma->vm_pgoff +
((vmstart - vma->vm_start) >> PAGE_SHIFT); ((vmstart - vma->vm_start) >> PAGE_SHIFT);
prev = vma_merge(mm, prev, vmstart, vmend, vma->vm_flags, prev = vmi_vma_merge(&vmi, mm, prev, vmstart, vmend, vma->vm_flags,
vma->anon_vma, vma->vm_file, pgoff, vma->anon_vma, vma->vm_file, pgoff,
new_pol, vma->vm_userfaultfd_ctx, new_pol, vma->vm_userfaultfd_ctx,
anon_vma_name(vma)); anon_vma_name(vma));
if (prev) { if (prev) {
/* vma_merge() invalidated the mas */
mas_pause(&mas);
vma = prev; vma = prev;
goto replace; goto replace;
} }
if (vma->vm_start != vmstart) { if (vma->vm_start != vmstart) {
err = split_vma(vma->vm_mm, vma, vmstart, 1); err = vmi_split_vma(&vmi, vma->vm_mm, vma, vmstart, 1);
if (err) if (err)
goto out; goto out;
/* split_vma() invalidated the mas */
mas_pause(&mas);
} }
if (vma->vm_end != vmend) { if (vma->vm_end != vmend) {
err = split_vma(vma->vm_mm, vma, vmend, 0); err = vmi_split_vma(&vmi, vma->vm_mm, vma, vmend, 0);
if (err) if (err)
goto out; goto out;
/* split_vma() invalidated the mas */
mas_pause(&mas);
} }
replace: replace:
err = vma_replace_policy(vma, new_pol); err = vma_replace_policy(vma, new_pol);
@ -843,7 +834,7 @@ replace:
goto out; goto out;
next: next:
prev = vma; prev = vma;
} } for_each_vma_range(vmi, vma, end);
out: out:
return err; return err;