Fixing SDFGState._read_and_write_sets() #1747

Merged
26 changes: 11 additions & 15 deletions dace/sdfg/state.py
@@ -786,26 +786,22 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr,
             # NOTE: In certain cases the corresponding subset might be None, in this case
             #  we assume that the whole array is written, which is the default behaviour.
             ac_desc = n.desc(self.sdfg)
-            ac_size = ac_desc.total_size
-            in_subsets = dict()
-            for in_edge in in_edges:
-                # Ensure that if the destination subset is not given, our assumption, that the
-                #  whole array is written to, is valid, by testing if the memlet transfers the
-                #  whole array.
-                assert (in_edge.data.dst_subset is not None) or (in_edge.data.num_elements() == ac_size)
-                in_subsets[in_edge] = (
-                    sbs.Range.from_array(ac_desc)
-                    if in_edge.data.dst_subset is None
-                    else in_edge.data.dst_subset
+            in_subsets = {
+                in_edge: (
+                    sbs.Range.from_array(ac_desc)
+                    if in_edge.data.dst_subset is None
+                    else in_edge.data.dst_subset
                 )
-            out_subsets = dict()
-            for out_edge in out_edges:
-                assert (out_edge.data.src_subset is not None) or (out_edge.data.num_elements() == ac_size)
-                out_subsets[out_edge] = (
+                for in_edge in in_edges
+            }
+            out_subsets = {
+                out_edge: (
                     sbs.Range.from_array(ac_desc)
                     if out_edge.data.src_subset is None
                     else out_edge.data.src_subset
                 )
+                for out_edge in out_edges
+            }

             # Update the read and write sets of the subgraph.
             if in_edges:
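For context, here is a minimal standalone sketch (SDFG name, array names, and shapes are chosen for illustration and are not part of the PR) of the fallback that the dict comprehensions above rely on: when an edge's memlet does not carry an explicit destination (or source) subset, the full range of the accessed array is substituted via `sbs.Range.from_array`. The removed assertion instead compared the memlet's element count against the array's `total_size`, which, as the new test below shows, cannot be verified symbolically when source and destination are sized by different but semantically equal symbols.

import dace
from dace import subsets as sbs

sdfg = dace.SDFG("subset_fallback_sketch")
state = sdfg.add_state()
sdfg.add_array("A", shape=(10, 20), dtype=dace.float64)
sdfg.add_array("B", shape=(10, 20), dtype=dace.float64)
A, B = state.add_access("A"), state.add_access("B")

# The memlet only names the source subset, so its destination subset is None.
state.add_nedge(A, B, dace.Memlet("A[0:10, 0:20]"))
edge = state.in_edges(B)[0]

# Fall back to the full destination range when no destination subset is given,
# mirroring the dict comprehensions in the hunk above.
dst = edge.data.dst_subset
subset = sbs.Range.from_array(B.desc(sdfg)) if dst is None else dst
print(subset)  # expected: 0:10, 0:20, i.e. the whole of B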
2 changes: 1 addition & 1 deletion tests/npbench/misc/stockham_fft_test.py
@@ -185,4 +185,4 @@ def test_fpga():
     elif target == "gpu":
         run_stockham_fft(dace.dtypes.DeviceType.GPU)
     elif target == "fpga":
-        run_stockham_fft(dace.dtypes.DeviceType.FPGA)
\ No newline at end of file
+        run_stockham_fft(dace.dtypes.DeviceType.FPGA)
47 changes: 47 additions & 0 deletions tests/sdfg/state_test.py
@@ -145,6 +145,52 @@ def test_read_and_write_set_selection():
                 assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'"


+def test_read_and_write_set_names():
+    sdfg = dace.SDFG('test_read_and_write_set_names')
+    state = sdfg.add_state(is_start_block=True)
+
+    # The arrays use different symbols for their sizes, but they are known to be the
+    #  same. This happens for example if the SDFG is the result of some automatic
+    #  translation from another IR, such as GTIR in GT4Py.
+    names = ["A", "B"]
+    for name in names:
+        sdfg.add_symbol(f"{name}_size_0", dace.int32)
+        sdfg.add_symbol(f"{name}_size_1", dace.int32)
+        sdfg.add_array(
+            name,
+            shape=(f"{name}_size_0", f"{name}_size_1"),
+            dtype=dace.float64,
+            transient=False,
+        )
+    A, B = (state.add_access(name) for name in names)
+
+    # Copy `A` into `B`.
+    #  Because `dst_subset` is `None`, we expect that everything is transferred.
+    state.add_nedge(
+        A,
+        B,
+        dace.Memlet("A[0:A_size_0, 0:A_size_1]"),
+    )
+    expected_read_set = {
+        "A": [sbs.Range.from_string("0:A_size_0, 0:A_size_1")],
+    }
+    expected_write_set = {
+        "B": [sbs.Range.from_string("0:B_size_0, 0:B_size_1")],
+    }
+    read_set, write_set = state._read_and_write_sets()
+
+    for expected_sets, computed_sets in [(expected_read_set, read_set), (expected_write_set, write_set)]:
+        assert expected_sets.keys() == computed_sets.keys(), f"Expected the set to contain '{expected_sets.keys()}' but got '{computed_sets.keys()}'."
+        for access_data in expected_sets.keys():
+            for exp in expected_sets[access_data]:
+                found_match = False
+                for res in computed_sets[access_data]:
+                    if res == exp:
+                        found_match = True
+                        break
+                assert found_match, f"Could not find the subset '{exp}' only got '{computed_sets}'"
+
+
 def test_add_mapped_tasklet():
     sdfg = dace.SDFG("test_add_mapped_tasklet")
     state = sdfg.add_state(is_start_block=True)
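The comment in the test above notes that `A` and `B` use different size symbols that are only known to be equal. A small sketch (illustrative only, not part of the PR) of why the previous `num_elements() == total_size` assertion could not be kept in that situation: the two element counts are distinct symbolic expressions, so the comparison does not evaluate to true.

import dace

a0, a1 = dace.symbol("A_size_0", dace.int32), dace.symbol("A_size_1", dace.int32)
b0, b1 = dace.symbol("B_size_0", dace.int32), dace.symbol("B_size_1", dace.int32)

# Structural comparison of the two element counts: different expressions,
# even though the test treats the sizes as equal at runtime.
print((a0 * a1) == (b0 * b1))  # False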
@@ -173,5 +219,6 @@ def test_add_mapped_tasklet():
     test_read_and_write_set_filter()
     test_read_write_set()
     test_read_write_set_y_formation()
+    test_read_and_write_set_names()
     test_deepcopy_state()
     test_add_mapped_tasklet()