@@ -335,6 +335,8 @@ class SumExprNode : public CanonicalExprNode {
335335 * \return whether the cast can be safely pushed to children
336336 */
337337 bool CanPushCastToChildren (DataType dtype, Analyzer* analyzer) const {
338+ bool is_min_value = dtype.bits () == 64 ? base == std::numeric_limits<int64_t >::lowest ()
339+ : base == -(1LL << (dtype.bits () - 1 ));
338340 // cast(dtype, arg_1 + arg_2 + ... arg_n) ==
339341 // cast(dtype, arg_1) + ... + cast(dtype, arg_n)
340342 // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
@@ -351,7 +353,7 @@ class SumExprNode : public CanonicalExprNode {
351353 }
352354 }
353355 }
354- if (base > 0 ) {
356+ if (base > 0 || is_min_value ) {
355357 res = res + make_const (dtype, base);
356358 if (!CastIsSafe (dtype, res, analyzer)) {
357359 return false ;
@@ -366,7 +368,7 @@ class SumExprNode : public CanonicalExprNode {
366368 }
367369 }
368370 }
369- if (base < 0 ) {
371+ if (base < 0 && !is_min_value ) {
370372 res = res - make_const (dtype, -base);
371373 if (!CastIsSafe (dtype, res, analyzer)) {
372374 return false ;
@@ -497,14 +499,16 @@ class SumExprNode : public CanonicalExprNode {
497499 return args;
498500 }
499501 static PrimExpr Normalize_ (DataType dtype, const std::vector<SplitExpr>& args, int64_t base) {
502+ bool is_min_value = dtype.bits () == 64 ? base == std::numeric_limits<int64_t >::lowest ()
503+ : base == -(1LL << (dtype.bits () - 1 ));
500504 // Positive scales first
501505 PrimExpr res = make_const (dtype, 0 );
502506 for (size_t i = 0 ; i < args.size (); ++i) {
503507 if (args[i]->scale > 0 ) {
504508 res = res + args[i]->Normalize ();
505509 }
506510 }
507- if (base > 0 ) {
511+ if (base > 0 || is_min_value ) {
508512 res = res + make_const (dtype, base);
509513 }
510514 // negative scales follows using sub.
@@ -513,7 +517,7 @@ class SumExprNode : public CanonicalExprNode {
513517 res = res - args[i]->NormalizeWithScale (-1 );
514518 }
515519 }
516- if (base < 0 ) {
520+ if (base < 0 && !is_min_value ) {
517521 res = res - make_const (dtype, -base);
518522 }
519523 return res;
0 commit comments