@@ -74,7 +74,7 @@ bool is_safe_for_add(const Expr &e, int max_depth) {
         } else if (cast->type.bits() == cast->value.type().bits()) {
             return is_safe_for_add(cast->value, max_depth);
         }
-    } else if (Call::as_intrinsic(e, {Call::widening_add, Call::widening_sub})) {
+    } else if (Call::as_intrinsic(e, {Call::widening_add, Call::widening_sub, Call::widen_right_add, Call::widen_right_sub})) {
         return true;
     }
     return false;
@@ -131,20 +131,47 @@ Expr to_rounding_shift(const Call *c) {
     }
     Expr round;
     if (c->is_intrinsic(Call::shift_right)) {
-        round = simplify((make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2);
+        round = (make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2;
     } else {
-        round = simplify((make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2);
+        round = (make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2;
     }
+    // Input expressions are simplified before running find_intrinsics, but b
+    // has been lifted here, so we need to lower_intrinsics before simplifying
+    // and re-lifting. Should we move this code into the FindIntrinsics class
+    // to make it easier to lift round?
+    round = lower_intrinsics(round);
+    round = simplify(round);
+    round = find_intrinsics(round);
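+    // For example, for an i16 shift (x >> 4), round works out to 8, so the
+    // matching below can accept widening_add(y, 8) >> 4 and rewrite it as
+    // rounding_shift(y, 4).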
 
     // We can always handle widening adds.
     if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) {
-        if (can_prove(lower_intrinsics(add->args[0]) == round)) {
+        if (can_prove(lower_intrinsics(add->args[0] == round))) {
             return rounding_shift(cast(add->type, add->args[1]), b);
-        } else if (can_prove(lower_intrinsics(add->args[1]) == round)) {
+        } else if (can_prove(lower_intrinsics(add->args[1] == round))) {
             return rounding_shift(cast(add->type, add->args[0]), b);
         }
     }
 
+    if (const Call *add = Call::as_intrinsic(a, {Call::widen_right_add})) {
+        if (can_prove(lower_intrinsics(add->args[1] == round))) {
+            return rounding_shift(cast(add->type, add->args[0]), b);
+        }
+    }
+    // Also need to handle the annoying case of a reinterpret wrapping a widen_right_add.
+    // TODO: this pattern makes me want to change the semantics of this op.
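+    // (The reinterpret is introduced when find_intrinsics matches a widen_right_add
+    // whose operands have mismatched type codes; see the Add/Sub/Mul visitors below.)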
+    if (const Reinterpret *reinterp = a.as<Reinterpret>()) {
+        if (reinterp->type.bits() == reinterp->value.type().bits()) {
+            if (const Call *add = Call::as_intrinsic(reinterp->value, {Call::widen_right_add})) {
+                if (can_prove(lower_intrinsics(add->args[1] == round))) {
+                    // We expect the first operand to be a reinterpret.
+                    const Reinterpret *reinterp_a = add->args[0].as<Reinterpret>();
+                    internal_assert(reinterp_a) << "Failed: " << add->args[0] << "\n";
+                    return rounding_shift(reinterp_a->value, b);
+                }
+            }
+        }
+    }
+
     // If it wasn't a widening or saturating add, we might still
     // be able to safely accept the rounding.
     Expr a_less_round = find_and_subtract(a, round);
@@ -199,6 +226,53 @@ class FindIntrinsics : public IRMutator {
             }
         }
 
+        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
+            // Look for widen_right_add intrinsics.
+            // Yes, this duplicates code, but we want to check op->type.code() first,
+            // and the opposite code as well.
+            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
+                Type narrow = op->type.narrow().with_code(code);
+                // Pulling casts out of VectorReduce nodes breaks too much codegen, skip for now.
+                Expr narrow_a = (a.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, a);
+                Expr narrow_b = (b.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, b);
+
+                // This case should have been handled by the check for widening_add above.
+                internal_assert(!(narrow_a.defined() && narrow_b.defined()))
+                    << "find_intrinsics failed to find a widening_add: " << a << " + " << b << "\n";
+
+                if (narrow_a.defined()) {
+                    Expr result;
+                    if (b.type().code() != narrow_a.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = b.type().with_code(code);
+                        result = widen_right_add(reinterpret(t, b), narrow_a);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_add(b, narrow_a);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return result;
+                } else if (narrow_b.defined()) {
+                    Expr result;
+                    if (a.type().code() != narrow_b.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = a.type().with_code(code);
+                        result = widen_right_add(reinterpret(t, a), narrow_b);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_add(a, narrow_b);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return mutate(result);
+                }
+            }
+        }
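+        // For example, u16(x) + u16(u8(y)) becomes widen_right_add(u16(x), u8(y));
+        // when the type codes differ (e.g. i16(x) + i16(u8(y))), the wide operand
+        // is reinterpreted to match and the result is reinterpreted back.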
+
+        // TODO: there could be widen_right_add + widen_right_add simplification rules,
+        // e.g. widen_right_add(a, b) + widen_right_add(c, d) = (a + c) + widening_add(b, d).
+
         if (a.same_as(op->a) && b.same_as(op->b)) {
             return op;
         } else {
@@ -240,6 +314,32 @@ class FindIntrinsics : public IRMutator {
             return Add::make(a, negative_b);
         }
 
+        // Run after the lossless_negate check, because we want that to turn into a widen_right_add if relevant.
+        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
+            // Look for widen_right_sub intrinsics.
+            // Yes, this duplicates code, but we want to check op->type.code() first,
+            // and the opposite code as well.
+            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
+                Type narrow = op->type.narrow().with_code(code);
+                Expr narrow_b = lossless_cast(narrow, b);
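+                // (Unlike add and mul, sub is not commutative, so only a narrow
+                // right-hand side can become a widen_right_sub.)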
+
+                if (narrow_b.defined()) {
+                    Expr result;
+                    if (a.type().code() != narrow_b.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = a.type().with_code(code);
+                        result = widen_right_sub(reinterpret(t, a), narrow_b);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_sub(a, narrow_b);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return mutate(result);
+                }
+            }
+        }
+
         if (a.same_as(op->a) && b.same_as(op->b)) {
             return op;
         } else {
@@ -292,6 +392,49 @@ class FindIntrinsics : public IRMutator {
             return mutate(result);
         }
 
+        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
+            // Look for widen_right_mul intrinsics.
+            // Yes, this duplicates code, but we want to check op->type.code() first,
+            // and the opposite code as well.
+            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
+                Type narrow = op->type.narrow().with_code(code);
+                Expr narrow_a = lossless_cast(narrow, a);
+                Expr narrow_b = lossless_cast(narrow, b);
+
+                // This case should have been handled by the check for widening_mul above.
+                internal_assert(!(narrow_a.defined() && narrow_b.defined()))
+                    << "find_intrinsics failed to find a widening_mul: " << a << " * " << b << "\n";
+
+                if (narrow_a.defined()) {
+                    Expr result;
+                    if (b.type().code() != narrow_a.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = b.type().with_code(code);
+                        result = widen_right_mul(reinterpret(t, b), narrow_a);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_mul(b, narrow_a);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return result;
+                } else if (narrow_b.defined()) {
+                    Expr result;
+                    if (a.type().code() != narrow_b.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = a.type().with_code(code);
+                        result = widen_right_mul(reinterpret(t, a), narrow_b);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_mul(a, narrow_b);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return mutate(result);
+                }
+            }
+        }
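+        // For example, u16(x) * u16(u8(y)) becomes widen_right_mul(u16(x), u8(y)),
+        // which lowers to x * widen(u8(y)) if no backend supports it directly.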
+
         if (a.same_as(op->a) && b.same_as(op->b)) {
             return op;
         } else {
@@ -594,6 +737,37 @@ class FindIntrinsics : public IRMutator {
         const auto is_x_wider_opposite_int = (op->type.is_int() && is_uint(x, 2 * bits)) || (op->type.is_uint() && is_int(x, 2 * bits));
 
         if (
+            // Simplify extending patterns.
+            // (x + widen(y)) + widen(z) = x + widening_add(y, z).
+            rewrite(widen_right_add(widen_right_add(x, y), z),
+                    x + widening_add(y, z),
+                    // We only care about integers, so this should be trivially true.
+                    is_x_same_int_or_uint) ||
+
+            // (x - widen(y)) - widen(z) = x - widening_add(y, z).
+            rewrite(widen_right_sub(widen_right_sub(x, y), z),
+                    x - widening_add(y, z),
+                    // We only care about integers, so this should be trivially true.
+                    is_x_same_int_or_uint) ||
+
+            // (x + widen(y)) - widen(z) = x + cast(t, widening_sub(y, z)).
+            // The cast (a reinterpret) is needed only for uints.
+            rewrite(widen_right_sub(widen_right_add(x, y), z),
+                    x + widening_sub(y, z),
+                    is_x_same_int) ||
+            rewrite(widen_right_sub(widen_right_add(x, y), z),
+                    x + cast(op->type, widening_sub(y, z)),
+                    is_x_same_uint) ||
+
+            // (x - widen(y)) + widen(z) = x + cast(t, widening_sub(z, y)).
+            // The cast (a reinterpret) is needed only for uints.
+            rewrite(widen_right_add(widen_right_sub(x, y), z),
+                    x + widening_sub(z, y),
+                    is_x_same_int) ||
+            rewrite(widen_right_add(widen_right_sub(x, y), z),
+                    x + cast(op->type, widening_sub(z, y)),
+                    is_x_same_uint) ||
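+            // (widening_sub of two uints produces a signed result, which is why
+            // the uint cases above need the cast back to op->type.)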
+
             // Saturating patterns.
             rewrite(saturating_cast(op->type, widening_add(x, y)),
                     saturating_add(x, y),
@@ -679,6 +853,7 @@ class FindIntrinsics : public IRMutator {
                }
            }
        }
+        // TODO: do we want versions of widen_right_add here?
 
        if (op->is_intrinsic(Call::shift_right) || op->is_intrinsic(Call::shift_left)) {
            // Try to turn this into a widening shift.
@@ -885,6 +1060,18 @@ Expr find_intrinsics(const Expr &e) {
     return expr;
 }
 
+Expr lower_widen_right_add(const Expr &a, const Expr &b) {
+    return a + widen(b);
+}
+
+Expr lower_widen_right_mul(const Expr &a, const Expr &b) {
+    return a * widen(b);
+}
+
+Expr lower_widen_right_sub(const Expr &a, const Expr &b) {
+    return a - widen(b);
+}
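+// (widen() doubles the bit width while keeping the type code, so in each case
+// the widened b matches a's type, as asserted when these intrinsics are built.)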
+
 Expr lower_widening_add(const Expr &a, const Expr &b) {
     return widen(a) + widen(b);
 }
@@ -1100,7 +1287,16 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q)
 }
 
 Expr lower_intrinsic(const Call *op) {
-    if (op->is_intrinsic(Call::widening_add)) {
+    if (op->is_intrinsic(Call::widen_right_add)) {
+        internal_assert(op->args.size() == 2);
+        return lower_widen_right_add(op->args[0], op->args[1]);
+    } else if (op->is_intrinsic(Call::widen_right_mul)) {
+        internal_assert(op->args.size() == 2);
+        return lower_widen_right_mul(op->args[0], op->args[1]);
+    } else if (op->is_intrinsic(Call::widen_right_sub)) {
+        internal_assert(op->args.size() == 2);
+        return lower_widen_right_sub(op->args[0], op->args[1]);
+    } else if (op->is_intrinsic(Call::widening_add)) {
         internal_assert(op->args.size() == 2);
         return lower_widening_add(op->args[0], op->args[1]);
     } else if (op->is_intrinsic(Call::widening_mul)) {