Skip to content

Commit

Permalink
[Arith] Updated arith::DetectIterMap to keep extent=1 components
Browse files Browse the repository at this point in the history
Previously, arith::DetectIterMap simplified the output expression by
replacing iteration variables with extent==1 with their value.  This
prevented the return value from being used in
arith::InverseAffineIterMap to solve for the variable, as it no longer
existed in the returned expressions.

This commit changes arith::DetectIterMap to keep the iteration
variable even if extent==1, and adds a motivating unit test that
requires this updated behavior.
  • Loading branch information
Lunderberg committed Apr 12, 2022
1 parent 11d22bd commit df4919c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def convert_iter_expr(expr):
def assert_iter_sum_pattern(sum_expr, extent, base, scale=1):
"""Check the sum expr have the right pattern."""
assert isinstance(sum_expr, tvm.arith.IterSumExpr)
if extent == 1:
if extent is None:
assert len(sum_expr.args) == 0
else:
assert len(sum_expr.args) == 1
Expand All @@ -69,12 +69,12 @@ def test_trivial():
assert len(res) == 3
assert_iter_sum_pattern(res[0], 3, 0)
assert_iter_sum_pattern(res[1], 4, 0)
assert_iter_sum_pattern(res[2], 1, 3)
assert_iter_sum_pattern(res[2], None, 3)

res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y]))
assert len(res) == 2
assert_iter_sum_pattern(res[0], 3, 0)
assert_iter_sum_pattern(res[1], 1, 3)
assert_iter_sum_pattern(res[1], None, 3)

# not independent
res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y]))
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,5 +545,35 @@ def test_transform_with_reduction():
tvm.lower(s, [A, B])


shape, transform = tvm.testing.parameters(
([1, 8], lambda n, i: [i, n]),
([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]),
([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]),
)


def test_size_one_buffer(shape, transform):
# This test is to catch a failure mode that occurred if a
# transformation were applied to a te.compute buffer, and one of
# the dimensions of the buffer was 1. Prior to bugfix,
# arith::DetectIterMap would fold the variable as a constant,
# causing an error when attempting to solve for the variable using
# arith::InverseAffineIterMap.

dtype = "int8"
A = te.placeholder(shape, dtype, name="A")
B = te.compute(
shape=A.shape,
fcompute=lambda *indices: A[indices].astype(dtype),
name="B",
)
s = te.create_schedule(B.op)

# If layout transformation is on the output buffer, and any
# dimension of the output buffer is 1, failure occurs in
# CheckFusePattern.
s[B].transform_layout(transform)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit df4919c

Please sign in to comment.