diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 3dd6ee1c2b596..ec4740254b4b2 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -51,7 +51,7 @@ def convert_iter_expr(expr): def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): """Check the sum expr have the right pattern.""" assert isinstance(sum_expr, tvm.arith.IterSumExpr) - if extent == 1: + if extent is None: assert len(sum_expr.args) == 0 else: assert len(sum_expr.args) == 1 @@ -69,12 +69,12 @@ def test_trivial(): assert len(res) == 3 assert_iter_sum_pattern(res[0], 3, 0) assert_iter_sum_pattern(res[1], 4, 0) - assert_iter_sum_pattern(res[2], 1, 3) + assert_iter_sum_pattern(res[2], None, 3) res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y])) assert len(res) == 2 assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 1, 3) + assert_iter_sum_pattern(res[1], None, 3) # not independent res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y])) diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index 28399498c784f..e7d5f125dc681 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -545,5 +545,35 @@ def test_transform_with_reduction(): tvm.lower(s, [A, B]) +shape, transform = tvm.testing.parameters( + ([1, 8], lambda n, i: [i, n]), + ([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]), + ([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]), +) + + +def test_size_one_buffer(shape, transform): + # This test is to catch a failure mode that occurred if a + # transformation were applied to a te.compute buffer, and one of + # the dimensions of the buffer was 1. Prior to bugfix, + # arith::DetectIterMap would fold the variable as a constant, + # causing an error when attempting to solve for the variable using + # arith::InverseAffineIterMap. + + dtype = "int8" + A = te.placeholder(shape, dtype, name="A") + B = te.compute( + shape=A.shape, + fcompute=lambda *indices: A[indices].astype(dtype), + name="B", + ) + s = te.create_schedule(B.op) + + # If layout transformation is on the output buffer, and any + # dimension of the output buffer is 1, failure occurs in + # CheckFusePattern. + s[B].transform_layout(transform) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv))