diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c index 6cd4dd631a2f..8eef3ed5fe04 100644 --- a/arch/x86/kvm/mmu/tdp_mmu.c +++ b/arch/x86/kvm/mmu/tdp_mmu.c @@ -1506,6 +1506,16 @@ void kvm_tdp_mmu_try_split_huge_pages(struct kvm *kvm, } } +static bool tdp_mmu_need_write_protect(struct kvm_mmu_page *sp) +{ + /* + * All TDP MMU shadow pages share the same role as their root, aside + * from level, so it is valid to key off any shadow page to determine if + * write protection is needed for an entire tree. + */ + return kvm_mmu_page_ad_need_write_protect(sp) || !kvm_ad_enabled(); +} + /* * Clear the dirty status of all the SPTEs mapping GFNs in the memslot. If * AD bits are enabled, this will involve clearing the dirty bit on each SPTE. @@ -1516,7 +1526,8 @@ void kvm_tdp_mmu_try_split_huge_pages(struct kvm *kvm, static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root, gfn_t start, gfn_t end) { - u64 dbit = kvm_ad_enabled() ? shadow_dirty_mask : PT_WRITABLE_MASK; + const u64 dbit = tdp_mmu_need_write_protect(root) ? PT_WRITABLE_MASK : + shadow_dirty_mask; struct tdp_iter iter; bool spte_set = false; @@ -1530,7 +1541,7 @@ retry: if (!is_shadow_present_pte(iter.old_spte)) continue; - KVM_MMU_WARN_ON(kvm_ad_enabled() && + KVM_MMU_WARN_ON(dbit == shadow_dirty_mask && spte_ad_need_write_protect(iter.old_spte)); if (!(iter.old_spte & dbit)) @@ -1578,8 +1589,8 @@ bool kvm_tdp_mmu_clear_dirty_slot(struct kvm *kvm, static void clear_dirty_pt_masked(struct kvm *kvm, struct kvm_mmu_page *root, gfn_t gfn, unsigned long mask, bool wrprot) { - u64 dbit = (wrprot || !kvm_ad_enabled()) ? PT_WRITABLE_MASK : - shadow_dirty_mask; + const u64 dbit = (wrprot || tdp_mmu_need_write_protect(root)) ? PT_WRITABLE_MASK : + shadow_dirty_mask; struct tdp_iter iter; lockdep_assert_held_write(&kvm->mmu_lock); @@ -1591,7 +1602,7 @@ static void clear_dirty_pt_masked(struct kvm *kvm, struct kvm_mmu_page *root, if (!mask) break; - KVM_MMU_WARN_ON(kvm_ad_enabled() && + KVM_MMU_WARN_ON(dbit == shadow_dirty_mask && spte_ad_need_write_protect(iter.old_spte)); if (iter.level > PG_LEVEL_4K ||