diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index a56336fa8d..1ff286b85c 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -1146,10 +1146,15 @@ def change_data(transient_array, shape, strides, total_size, offset, lifetime, s # by reconnecting their adjacent edges to nodes outside the subgraph. # NOTE: Currently limited to cases where there is a single source and sink # if there are multiple intermediate accesses for the same data. + # NOTE: Currently limited to intermediate data that do not have a separate output node + + # Filter out outputs + output_data = set([n.data for n in out_nodes]) + true_intermediate_nodes = set([n for n in intermediate_nodes if n.data not in output_data]) # Sort intermediate nodes by data name intermediate_data = dict() - for acc in intermediate_nodes: + for acc in true_intermediate_nodes: if acc.data in intermediate_data: intermediate_data[acc.data].append(acc) else: