Skip to content

Commit ce45b2a

Browse files
committed
Solve on canonical simplification side
1 parent f948d27 commit ce45b2a

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

src/arith/canonical_simplify.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,8 @@ class SumExprNode : public CanonicalExprNode {
335335
* \return whether the cast can be safely pushed to children
336336
*/
337337
bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
338+
bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits<int64_t>::lowest()
339+
: base == -(1LL << (dtype.bits() - 1));
338340
// cast(dtype, arg_1 + arg_2 + ... arg_n) ==
339341
// cast(dtype, arg_1) + ... + cast(dtype, arg_n)
340342
// iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
@@ -351,7 +353,7 @@ class SumExprNode : public CanonicalExprNode {
351353
}
352354
}
353355
}
354-
if (base > 0) {
356+
if (base > 0 || is_min_value) {
355357
res = res + make_const(dtype, base);
356358
if (!CastIsSafe(dtype, res, analyzer)) {
357359
return false;
@@ -366,7 +368,7 @@ class SumExprNode : public CanonicalExprNode {
366368
}
367369
}
368370
}
369-
if (base < 0) {
371+
if (base < 0 && !is_min_value) {
370372
res = res - make_const(dtype, -base);
371373
if (!CastIsSafe(dtype, res, analyzer)) {
372374
return false;
@@ -497,14 +499,16 @@ class SumExprNode : public CanonicalExprNode {
497499
return args;
498500
}
499501
static PrimExpr Normalize_(DataType dtype, const std::vector<SplitExpr>& args, int64_t base) {
502+
bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits<int64_t>::lowest()
503+
: base == -(1LL << (dtype.bits() - 1));
500504
// Positive scales first
501505
PrimExpr res = make_const(dtype, 0);
502506
for (size_t i = 0; i < args.size(); ++i) {
503507
if (args[i]->scale > 0) {
504508
res = res + args[i]->Normalize();
505509
}
506510
}
507-
if (base > 0) {
511+
if (base > 0 || is_min_value) {
508512
res = res + make_const(dtype, base);
509513
}
510514
// negative scales follows using sub.
@@ -513,7 +517,7 @@ class SumExprNode : public CanonicalExprNode {
513517
res = res - args[i]->NormalizeWithScale(-1);
514518
}
515519
}
516-
if (base < 0) {
520+
if (base < 0 && !is_min_value) {
517521
res = res - make_const(dtype, -base);
518522
}
519523
return res;

tests/python/unittest/test_arith_canonical_simplify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,9 @@ def test_simplify_cast():
375375
def test_simplify_normalize_min_value_not_error():
376376
ck = CanonicalChecker()
377377
x = te.var("x", "int32")
378-
expr = te.min_value_symmetric("int32") - x == 0
378+
expr = te.min_value("int32") - x == 0
379379
# The simplify should not error in this case.
380-
ck.verify(expr, 0 - x == te.max_value("int32"))
380+
ck.verify(expr, x == te.min_value("int32"))
381381

382382

383383
if __name__ == "__main__":

0 commit comments

Comments
 (0)