diff --git a/arch/arm/include/asm/device.h b/arch/arm/include/asm/device.h
index 220ba207be91..36ec9c8f6e16 100644
--- a/arch/arm/include/asm/device.h
+++ b/arch/arm/include/asm/device.h
@@ -15,6 +15,9 @@ struct dev_archdata {
 #endif
 #ifdef CONFIG_ARM_DMA_USE_IOMMU
 	struct dma_iommu_mapping	*mapping;
+#endif
+#ifdef CONFIG_XEN
+	const struct dma_map_ops *dev_dma_ops;
 #endif
 	bool dma_coherent;
 };
diff --git a/arch/arm/include/asm/dma-mapping.h b/arch/arm/include/asm/dma-mapping.h
index 716656925975..680d3f3889e7 100644
--- a/arch/arm/include/asm/dma-mapping.h
+++ b/arch/arm/include/asm/dma-mapping.h
@@ -16,19 +16,9 @@
 extern const struct dma_map_ops arm_dma_ops;
 extern const struct dma_map_ops arm_coherent_dma_ops;
 
-static inline const struct dma_map_ops *__generic_dma_ops(struct device *dev)
-{
-	if (dev && dev->dma_ops)
-		return dev->dma_ops;
-	return &arm_dma_ops;
-}
-
 static inline const struct dma_map_ops *get_arch_dma_ops(struct bus_type *bus)
 {
-	if (xen_initial_domain())
-		return xen_dma_ops;
-	else
-		return __generic_dma_ops(NULL);
+	return &arm_dma_ops;
 }
 
 #define HAVE_ARCH_DMA_SUPPORTED 1
diff --git a/arch/arm/mm/dma-mapping.c b/arch/arm/mm/dma-mapping.c
index 475811f5383a..0268584f1fa0 100644
--- a/arch/arm/mm/dma-mapping.c
+++ b/arch/arm/mm/dma-mapping.c
@@ -2414,6 +2414,13 @@ void arch_setup_dma_ops(struct device *dev, u64 dma_base, u64 size,
 		dma_ops = arm_get_dma_map_ops(coherent);
 
 	set_dma_ops(dev, dma_ops);
+
+#ifdef CONFIG_XEN
+	if (xen_initial_domain()) {
+		dev->archdata.dev_dma_ops = dev->dma_ops;
+		dev->dma_ops = xen_dma_ops;
+	}
+#endif
 }
 
 void arch_teardown_dma_ops(struct device *dev)
diff --git a/arch/arm64/include/asm/device.h b/arch/arm64/include/asm/device.h
index 73d5bab015eb..5a5fa47a6b18 100644
--- a/arch/arm64/include/asm/device.h
+++ b/arch/arm64/include/asm/device.h
@@ -19,6 +19,9 @@
 struct dev_archdata {
 #ifdef CONFIG_IOMMU_API
 	void *iommu;			/* private IOMMU data */
+#endif
+#ifdef CONFIG_XEN
+	const struct dma_map_ops *dev_dma_ops;
 #endif
 	bool dma_coherent;
 };
diff --git a/arch/arm64/include/asm/dma-mapping.h b/arch/arm64/include/asm/dma-mapping.h
index 505756cdc67a..5392dbeffa45 100644
--- a/arch/arm64/include/asm/dma-mapping.h
+++ b/arch/arm64/include/asm/dma-mapping.h
@@ -27,11 +27,8 @@
 #define DMA_ERROR_CODE	(~(dma_addr_t)0)
 extern const struct dma_map_ops dummy_dma_ops;
 
-static inline const struct dma_map_ops *__generic_dma_ops(struct device *dev)
+static inline const struct dma_map_ops *get_arch_dma_ops(struct bus_type *bus)
 {
-	if (dev && dev->dma_ops)
-		return dev->dma_ops;
-
 	/*
 	 * We expect no ISA devices, and all other DMA masters are expected to
 	 * have someone call arch_setup_dma_ops at device creation time.
@@ -39,14 +36,6 @@ static inline const struct dma_map_ops *__generic_dma_ops(struct device *dev)
 	return &dummy_dma_ops;
 }
 
-static inline const struct dma_map_ops *get_arch_dma_ops(struct bus_type *bus)
-{
-	if (xen_initial_domain())
-		return xen_dma_ops;
-	else
-		return __generic_dma_ops(NULL);
-}
-
 void arch_setup_dma_ops(struct device *dev, u64 dma_base, u64 size,
 			const struct iommu_ops *iommu, bool coherent);
 #define arch_setup_dma_ops	arch_setup_dma_ops
diff --git a/arch/arm64/mm/dma-mapping.c b/arch/arm64/mm/dma-mapping.c
index 81cdb2e844ed..7f8b37e85a2b 100644
--- a/arch/arm64/mm/dma-mapping.c
+++ b/arch/arm64/mm/dma-mapping.c
@@ -977,4 +977,11 @@ void arch_setup_dma_ops(struct device *dev, u64 dma_base, u64 size,
 
 	dev->archdata.dma_coherent = coherent;
 	__iommu_setup_dma_ops(dev, dma_base, size, iommu);
+
+#ifdef CONFIG_XEN
+	if (xen_initial_domain()) {
+		dev->archdata.dev_dma_ops = dev->dma_ops;
+		dev->dma_ops = xen_dma_ops;
+	}
+#endif
 }
diff --git a/include/xen/arm/page-coherent.h b/include/xen/arm/page-coherent.h
index 95ce6ac3a971..b0a2bfc8d647 100644
--- a/include/xen/arm/page-coherent.h
+++ b/include/xen/arm/page-coherent.h
@@ -2,8 +2,16 @@
 #define _ASM_ARM_XEN_PAGE_COHERENT_H
 
 #include <asm/page.h>
+#include <asm/dma-mapping.h>
 #include <linux/dma-mapping.h>
 
+static inline const struct dma_map_ops *__generic_dma_ops(struct device *dev)
+{
+	if (dev && dev->archdata.dev_dma_ops)
+		return dev->archdata.dev_dma_ops;
+	return get_arch_dma_ops(NULL);
+}
+
 void __xen_dma_map_page(struct device *hwdev, struct page *page,
 	     dma_addr_t dev_addr, unsigned long offset, size_t size,
 	     enum dma_data_direction dir, unsigned long attrs);