Skip to content

Commit

Permalink
[ARITH][BUGFIX] Fix a bug of iter map floormod(x,2) simplify (#14571)
Browse files Browse the repository at this point in the history
This PR fixes a bug that was previously introduced in iter map detection.

Specifically, y - (x % 2) was simplified to y + (x % 2) - 1,
which is wrong. The correct rule is y + ((x + 1) % 2) - 1,
but that rule would change the base iterator, which is not desirable here.

We also removed the rule that simplifies (x + 1) % 2 => 1 - x % 2,
as its benefit is minimal and it introduces extra negative coefficients
that hurt analysis in general (negative coefficients are
harder to handle in many cases).
  • Loading branch information
tqchen authored Apr 11, 2023
1 parent fb2ae1a commit f622e7f
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 65 deletions.
5 changes: 0 additions & 5 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,11 +898,6 @@ class IterMapRewriter : public ExprMutator {
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);

static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
if (sign < 0 && is_const_int(rhs->extent, 2)) {
lhs->base -= rhs->scale;
sign = 1;
}

tir::ExprDeepEqual equal;
for (size_t i = 0; i < lhs->args.size(); ++i) {
IterSplitExpr lvalue = lhs->args[i];
Expand Down
17 changes: 13 additions & 4 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2));

// Simplify (x + 1) % 2 + x % 2 => 1
// NOTE: we should avoid simplifying (x + 1) % 2 => 1 - x % 2 though,
// mainly because introducing extra negative signs to expressions can harm iterator
// analysis, which usually relies on positive iterator coefficients.
TVM_TRY_REWRITE_IF(floormod(x + c1, 2) + floormod(x, 2), OneWithTypeLike(x),
floormod(c1.Eval()->value, 2) == 1);
TVM_TRY_REWRITE_IF(floormod(x, 2) + floormod(x + c1, 2), OneWithTypeLike(x),
floormod(c1.Eval()->value, 2) == 1);

// canonicalization rule
// will try rewrite again after canonicalization.

Expand Down Expand Up @@ -1018,10 +1027,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
c2.Eval()->value > 0);

TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 1,
floormod(c1.Eval()->value, 2) == 1);
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
// (x + 5) % 2 => (x + 1) % 2, (x + 3) % 3 => x % 3
TVM_TRY_REWRITE_IF(
floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2),
c2.Eval()->value > 0 && (c1.Eval()->value >= c2.Eval()->value || c1.Eval()->value < 0));

TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2),
c2.Eval()->value > 0);
Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,5 +415,12 @@ def test_proddiv_simplify():
ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))


def test_floormod_two():
ck = CanonicalChecker()
flm = tvm.te.floormod
x, y = te.var("x"), te.var("y")
ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)


if __name__ == "__main__":
tvm.testing.main()
10 changes: 5 additions & 5 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,14 @@ def test_compound():
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))


def test_compound_floormod_two():
def test_compound_floormod_two_regression():
x = tvm.tir.Var("x", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod

# extent of 2 are normalized to positive scale
assert_iter_sum_pattern(
expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)},
# regression
# extent of 2 of negative scale cannot be normalized
assert_iter_sum_failure(
[fld(x, 2) * 2 - flm(x, 2) + 1],
dom_map=var_dom([(x, 8)]),
)

Expand Down
22 changes: 15 additions & 7 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ class TestSubIndex(BaseCompare):
TestCase(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)),
TestCase(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1),
TestCase(fld(y, 3) * 3 - y, 0 - flm(y, 3)),
TestCase(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6),
TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)),
TestCase(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6),
TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5)),
TestCase(y - fld(y + z, 5) * 5, flm(y + z, 5) - z),
TestCase(fld(y + z, 5) * 5 - y, z - flm(y + z, 5)),
TestCase(y - fld(y - z, 5) * 5, flm(y - z, 5) + z),
Expand Down Expand Up @@ -554,13 +554,15 @@ class TestFloormodIndex(BaseCompare):
TestCase(flm(x + 10, 2), flm(x, 2)),
TestCase(flm(x + y * 10, 2), flm(x, 2)),
TestCase(flm(x + y * 360, 16), flm(x + y * 8, 16)),
TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
TestCase(flm(x * (-10), 2), 0),
TestCase(flm(x * (-10) + y, 2), flm(y, 2)),
TestCase(flm(x + (-10), 2), flm(x, 2)),
TestCase(flm(x + y * (-10), 2), flm(x, 2)),
TestCase(flm(x * 32 + y, 64), flm(x, 2) * 32 + y, [y >= 0, y < 32]),
TestCase(flm(x * 32 - y, 64), flm(x * 32 - y, 64), [y >= 0, y < 32]),
# NOTE: the following case is covered by canonical simplify;
# long-range simplification in general can be covered by canonical simplify
# TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
)


Expand All @@ -574,13 +576,14 @@ class TestFloorModTwo(BaseCompare):
require identifying more related terms in order to apply.
(x + c1)//2 - (x + c2)//2 => (x%2)*(c1%2 - c2%2) + (c1//2 - c2//2)
However, we should not introduce extra negative coefficients to iterators
during simplification
"""

x, y, z = te.var("x"), te.var("y"), te.var("z")
test_case = tvm.testing.parameter(
# Removing offsets from floormod
TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1),
TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1),
TestCase(flm(x, 2) + flm(x + 1, 2), 1),
TestCase(flm(x + 1, 2) + flm(x, 2), 1),
# Difference of floordiv yields floormod
Expand All @@ -592,8 +595,13 @@ class TestFloorModTwo(BaseCompare):
# Sum of floordiv and floormod to yield floordiv
TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)),
TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)),
# Removal of floormod where possible
TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]),
# regression: although we can rewrite (x + 1) % 2 => 1 - x % 2,
# doing so would introduce negative coefficients to iterators,
# which makes later iter map detection harder. In principle, we
# should not introduce additional negative signs on iterators when rewriting.
TestCase(flm(x + 1, 2), flm(x + 1, 2)),
TestCase(flm(x + 5, 2), flm(x + 1, 2)),
TestCase(flm(x + 1, 2) * 8192, flm(x + 1, 2) * 8192, [x >= 0, x < 2]),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def transformed_simple_compute(
for i in T.serial(0, 15):
with T.block():
T.reads([A[tx, i + 1]])
T.writes([B[1 - i % 2, tx, 0]])
B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
T.writes([B[(i + 1) % 2, tx, 0]])
B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
with T.block():
T.reads([B[i % 2, tx, 0]])
T.writes([C[tx, i]])
Expand Down Expand Up @@ -202,8 +202,8 @@ def transformed_simple_compute_with_other_annotation(
):
with T.block():
T.reads([A[tx, i + 1]])
T.writes([B[1 - i % 2, tx, 0]])
B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
T.writes([B[(i + 1) % 2, tx, 0]])
B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
with T.block():
T.reads([B[i % 2, tx, 0]])
T.writes([C[tx, i]])
Expand Down Expand Up @@ -266,7 +266,7 @@ def transformed_three_stage_compute(
T.where(i == 1)
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
with T.block():
T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
Expand All @@ -278,7 +278,7 @@ def transformed_three_stage_compute(
with T.block():
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
C[1 - i % 2, tx, 0] = B[1 - i % 2, tx, 0] + T.float32(2)
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
with T.block():
T.reads(C[0:2, tx, 0])
T.writes(D[tx, i])
Expand All @@ -291,7 +291,7 @@ def transformed_three_stage_compute(
T.where(i < 1)
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2)
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
with T.block():
T.reads(C[0:2, tx, 0])
T.writes(D[tx, i + 14])
Expand Down Expand Up @@ -391,12 +391,12 @@ def transformed_dag_interleaving(
BS[tx, 0] = B[tx, i + 1] + T.float32(2)
with T.block():
T.reads(AS[tx, 0])
T.writes(AL[1 - i % 2, 0, 0])
AL[1 - i % 2, 0, 0] = AS[tx, 0]
T.writes(AL[(i + 1) % 2, 0, 0])
AL[(i + 1) % 2, 0, 0] = AS[tx, 0]
with T.block():
T.reads(BS[tx, 0])
T.writes(BL[1 - i % 2, 0, 0])
BL[1 - i % 2, 0, 0] = BS[tx, 0]
T.writes(BL[(i + 1) % 2, 0, 0])
BL[(i + 1) % 2, 0, 0] = BS[tx, 0]
with T.block():
T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0])
T.writes(C[tx, i])
Expand Down Expand Up @@ -475,12 +475,12 @@ def transformed_nested_pipeline_simple(
for i in T.serial(0, 15):
with T.block():
T.reads([A[tx, i + 1, 0:16]])
T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
for j in T.serial(0, 16):
with T.block():
T.reads([A[tx, i + 1, j]])
T.writes([A_shared[1 - i % 2, tx, 0, j]])
A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
with T.block():
T.reads([A_shared[i % 2, tx, i, 0]])
T.writes([B[0, tx, i, 0]])
Expand All @@ -491,10 +491,10 @@ def transformed_nested_pipeline_simple(
for j in T.serial(0, 15):
with T.block():
T.reads([A_shared[i % 2, tx, i, j + 1]])
T.writes([B[1 - j % 2, tx, i, 0]])
B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32(
2
)
T.writes([B[(j + 1) % 2, tx, i, 0]])
B[(j + 1) % 2, tx, i, 0] = A_shared[
i % 2, tx, 0, j + 1
] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, i, 0]])
T.writes([C[tx, i, j]])
Expand All @@ -516,8 +516,8 @@ def transformed_nested_pipeline_simple(
for j in T.serial(0, 15):
with T.block():
T.reads([A_shared[1, tx, 15, j + 1]])
T.writes([B[1 - j % 2, tx, 15, 0]])
B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
T.writes([B[(j + 1) % 2, tx, 15, 0]])
B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, 15, 0]])
T.writes([C[tx, 15, j]])
Expand Down Expand Up @@ -603,30 +603,30 @@ def transformed_nested_pipeline_prefetch_inner(
for i in T.serial(0, 15):
with T.block():
T.reads([A[tx, i + 1, 0:16]])
T.writes([A_shared[1 - i % 2, tx, 0, 0:16]])
T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]])
for j in T.serial(0, 16):
with T.block():
T.reads([A[tx, i + 1, j]])
T.writes([A_shared[1 - i % 2, tx, 0, j]])
A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j]
T.writes([A_shared[(i + 1) % 2, tx, 0, j]])
A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j]
with T.block():
T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
for j in T.serial(0, 15):
with T.block():
T.reads([A_shared[i % 2, tx, i, j + 1]])
T.writes([B[1 - j % 2, tx, i, 0]])
B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32(
2
)
T.writes([B[(j + 1) % 2, tx, i, 0]])
B[(j + 1) % 2, tx, i, 0] = A_shared[
i % 2, tx, 0, j + 1
] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, i, 0]])
T.writes([C[tx, i, j]])
C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
with T.block():
T.reads([A_shared[1 - i % 2, tx, i + 1, 0]])
T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]])
T.writes([B[0, tx, i + 1, 0]])
B[0, tx, i + 1, 0] = A_shared[1 - i % 2, tx, 0, 0] * T.float32(2)
B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2)
with T.block():
T.reads([B[1, tx, i, 0]])
T.writes([C[tx, i, 15]])
Expand All @@ -640,8 +640,8 @@ def transformed_nested_pipeline_prefetch_inner(
for j in T.serial(0, 15):
with T.block():
T.reads([A_shared[1, tx, 15, j + 1]])
T.writes([B[1 - j % 2, tx, 15, 0]])
B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
T.writes([B[(j + 1) % 2, tx, 15, 0]])
B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, 15, 0]])
T.writes([C[tx, 15, j]])
Expand Down Expand Up @@ -768,8 +768,8 @@ def transformed_nested_pipeline_interleaving(
for j in T.serial(0, 15):
with T.block():
T.reads([A_local[tx, i, j + 1]])
T.writes([B[1 - j % 2, tx, i, 0]])
B[1 - j % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2)
T.writes([B[(j + 1) % 2, tx, i, 0]])
B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, i, 0]])
T.writes([C[tx, i, j]])
Expand Down Expand Up @@ -799,8 +799,8 @@ def transformed_nested_pipeline_interleaving(
for j in T.serial(0, 15):
with T.block():
T.reads([A_local[tx, 15, j + 1]])
T.writes([B[1 - j % 2, tx, 15, 0]])
B[1 - j % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2)
T.writes([B[(j + 1) % 2, tx, 15, 0]])
B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, 15, 0]])
T.writes([C[tx, 15, j]])
Expand Down Expand Up @@ -929,25 +929,27 @@ def transformed_nested_pipeline_double_buffer(
for j in T.serial(0, 15):
with T.block():
T.reads([A_local[i % 2, tx, i, j + 1]])
T.writes([B[1 - j % 2, tx, i, 0]])
B[1 - j % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(2)
T.writes([B[(j + 1) % 2, tx, i, 0]])
B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(
2
)
with T.block():
T.reads([B[j % 2, tx, i, 0]])
T.writes([C[tx, i, j]])
C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
with T.block():
T.reads([A_shared[tx, 0, 0:16]])
T.writes([A_local[1 - i % 2, 0, 0, 0:16]])
T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
for j in T.serial(0, 16):
with T.block():
T.reads([A_shared[tx, 0, j]])
T.writes([A_local[1 - i % 2, 0, 0, j]])
T.writes([A_local[(i + 1) % 2, 0, 0, j]])
T.block_attr({"double_buffer_scope": 0})
A_local[1 - i % 2, 0, 0, j] = A_shared[tx, i + 1, j]
A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j]
with T.block():
T.reads([A_local[1 - i % 2, tx, i + 1, 0]])
T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]])
T.writes([B[0, tx, i + 1, 0]])
B[0, tx, i + 1, 0] = A_local[1 - i % 2, 0, 0, 0] * T.float32(2)
B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2)
with T.block():
T.reads([B[1, tx, i, 0]])
T.writes([C[tx, i, 15]])
Expand All @@ -961,8 +963,8 @@ def transformed_nested_pipeline_double_buffer(
for j in T.serial(0, 15):
with T.block():
T.reads([A_local[1, tx, 15, j + 1]])
T.writes([B[1 - j % 2, tx, 15, 0]])
B[1 - j % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2)
T.writes([B[(j + 1) % 2, tx, 15, 0]])
B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2)
with T.block():
T.reads([B[j % 2, tx, 15, 0]])
T.writes([C[tx, 15, j]])
Expand Down

0 comments on commit f622e7f

Please sign in to comment.