diff --git a/drivers/iommu/apple-dart.c b/drivers/iommu/apple-dart.c index 42666617803d..1fc64a42c890 100644 --- a/drivers/iommu/apple-dart.c +++ b/drivers/iommu/apple-dart.c @@ -835,6 +835,29 @@ static void apple_dart_release_group(void *iommu_data) mutex_unlock(&apple_dart_groups_lock); } +static int apple_dart_merge_master_cfg(struct apple_dart_master_cfg *dst, + struct apple_dart_master_cfg *src) +{ + /* + * We know that this function is only called for groups returned from + * pci_device_group and that all Apple Silicon platforms never spread + * PCIe devices from the same bus across multiple DARTs such that we can + * just assume that both src and dst only have the same single DART. + */ + if (src->stream_maps[1].dart) + return -EINVAL; + if (dst->stream_maps[1].dart) + return -EINVAL; + if (src->stream_maps[0].dart != dst->stream_maps[0].dart) + return -EINVAL; + + bitmap_or(dst->stream_maps[0].sidmap, + dst->stream_maps[0].sidmap, + src->stream_maps[0].sidmap, + dst->stream_maps[0].dart->num_streams); + return 0; +} + static struct iommu_group *apple_dart_device_group(struct device *dev) { int i, sid; @@ -876,14 +899,28 @@ static struct iommu_group *apple_dart_device_group(struct device *dev) if (!group) goto out; - group_master_cfg = kmemdup(cfg, sizeof(*group_master_cfg), GFP_KERNEL); - if (!group_master_cfg) { - iommu_group_put(group); - goto out; - } + group_master_cfg = iommu_group_get_iommudata(group); + if (group_master_cfg) { + int ret; - iommu_group_set_iommudata(group, group_master_cfg, - apple_dart_release_group); + ret = apple_dart_merge_master_cfg(group_master_cfg, cfg); + if (ret) { + dev_err(dev, "Failed to merge DART IOMMU grups.\n"); + iommu_group_put(group); + res = ERR_PTR(ret); + goto out; + } + } else { + group_master_cfg = kmemdup(cfg, sizeof(*group_master_cfg), + GFP_KERNEL); + if (!group_master_cfg) { + iommu_group_put(group); + goto out; + } + + iommu_group_set_iommudata(group, group_master_cfg, + apple_dart_release_group); + } for_each_stream_map(i, cfg, stream_map) for_each_set_bit(sid, stream_map->sidmap, stream_map->dart->num_streams)