Commit a2f65ce

rootjalex authored and steven-johnson committed
Add one-sided widening intrinsics. (halide#6967)
* implement widen_right_ ops
* update HVX patterns with one-sided widening intrinsics
* remove unused HVX pattern flags
* strengthen logic for finding rounding shifts

Co-authored-by: Steven Johnson <srj@google.com>
1 parent 780e361 commit a2f65ce

File tree: 12 files changed, +373 -84 lines changed

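For orientation before the diffs: each new widen_right_* op takes a wide first operand and a narrow second operand of half the bit width, and is defined as widening the right operand before applying the binary op (see the lower_widen_right_* functions added to FindIntrinsics.cpp below). A minimal semantics sketch in Halide's internal IR, with invented variable names, using the helpers as this diff uses them:

    // Sketch (not in-tree code): widen_right_add(a, b) == a + widen(b),
    // where widen(b) casts b up to twice its bit width, i.e. to a's type.
    Expr a = Variable::make(UInt(32, 8), "a");  // wide operand: uint32x8
    Expr b = Variable::make(UInt(16, 8), "b");  // narrow operand: uint16x8
    Expr add = widen_right_add(a, b);           // == a + cast(UInt(32, 8), b)
    Expr sub = widen_right_sub(a, b);           // == a - cast(UInt(32, 8), b)
    Expr mul = widen_right_mul(a, b);           // == a * cast(UInt(32, 8), b)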

src/CodeGen_ARM.cpp

Lines changed: 1 addition & 0 deletions
@@ -1331,6 +1331,7 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init
     if (narrow.defined()) {
         if (init.defined() && target.bits == 32) {
             // On 32-bit, we have an intrinsic for widening add-accumulate.
+            // TODO: this could be written as a pattern with widen_right_add (#6951).
             intrin = "pairwise_widening_add_accumulate";
             intrin_args = {accumulator, narrow};
             accumulator = Expr();

src/FindIntrinsics.cpp

Lines changed: 202 additions & 6 deletions
@@ -74,7 +74,7 @@ bool is_safe_for_add(const Expr &e, int max_depth) {
         } else if (cast->type.bits() == cast->value.type().bits()) {
             return is_safe_for_add(cast->value, max_depth);
         }
-    } else if (Call::as_intrinsic(e, {Call::widening_add, Call::widening_sub})) {
+    } else if (Call::as_intrinsic(e, {Call::widening_add, Call::widening_sub, Call::widen_right_add, Call::widen_right_sub})) {
         return true;
     }
     return false;
@@ -131,20 +131,47 @@ Expr to_rounding_shift(const Call *c) {
     }
     Expr round;
     if (c->is_intrinsic(Call::shift_right)) {
-        round = simplify((make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2);
+        round = (make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2;
     } else {
-        round = simplify((make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2);
+        round = (make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2;
     }
+    // Input expressions are simplified before running find_intrinsics, but b
+    // has been lifted here so we need to lower_intrinsics before simplifying
+    // and re-lifting. Should we move this code into the FindIntrinsics class
+    // to make it easier to lift round?
+    round = lower_intrinsics(round);
+    round = simplify(round);
+    round = find_intrinsics(round);
 
     // We can always handle widening adds.
     if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) {
-        if (can_prove(lower_intrinsics(add->args[0]) == round)) {
+        if (can_prove(lower_intrinsics(add->args[0] == round))) {
             return rounding_shift(cast(add->type, add->args[1]), b);
-        } else if (can_prove(lower_intrinsics(add->args[1]) == round)) {
+        } else if (can_prove(lower_intrinsics(add->args[1] == round))) {
             return rounding_shift(cast(add->type, add->args[0]), b);
         }
     }
 
+    if (const Call *add = Call::as_intrinsic(a, {Call::widen_right_add})) {
+        if (can_prove(lower_intrinsics(add->args[1] == round))) {
+            return rounding_shift(cast(add->type, add->args[0]), b);
+        }
+    }
+    // Also need to handle the annoying case of a reinterpret wrapping a widen_right_add.
+    // TODO: this pattern makes me want to change the semantics of this op.
+    if (const Reinterpret *reinterp = a.as<Reinterpret>()) {
+        if (reinterp->type.bits() == reinterp->value.type().bits()) {
+            if (const Call *add = Call::as_intrinsic(reinterp->value, {Call::widen_right_add})) {
+                if (can_prove(lower_intrinsics(add->args[1] == round))) {
+                    // We expect the first operand to be a reinterpret.
+                    const Reinterpret *reinterp_a = add->args[0].as<Reinterpret>();
+                    internal_assert(reinterp_a) << "Failed: " << add->args[0] << "\n";
+                    return rounding_shift(reinterp_a->value, b);
+                }
+            }
+        }
+    }
+
     // If it wasn't a widening or saturating add, we might still
     // be able to safely accept the rounding.
     Expr a_less_round = find_and_subtract(a, round);
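Worth spelling out what `round` is: half the effective divisor, so adding it before the shift turns truncation into round-to-nearest. A standalone scalar model of the identity this function matches (plain C++, not Halide code; assumes n > 0 and no overflow in x + round):

    #include <cassert>
    #include <cstdint>

    // Scalar model of rounding_shift_right: add half the divisor, then shift.
    uint32_t rounding_shift_right_ref(uint32_t x, unsigned n) {
        uint32_t round = (1u << n) / 2;  // the "round" expression computed above
        return (x + round) >> n;
    }

    int main() {
        assert(rounding_shift_right_ref(7, 2) == 2);  // 7/4 = 1.75, rounds to 2
        assert(rounding_shift_right_ref(5, 2) == 1);  // 5/4 = 1.25, rounds to 1
        assert((7u >> 2) == 1);                       // a plain shift truncates
    }

to_rounding_shift's job is to recognize that `a` contains this `round` term, possibly buried inside a widening_add, a widen_right_add, or (per the new branch) a same-width reinterpret of one, and strip it out.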
@@ -199,6 +226,53 @@ class FindIntrinsics : public IRMutator {
             }
         }
 
+        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
+            // Look for widen_right_add intrinsics.
+            // Yes, we duplicate code here, but we want to check op->type.code()
+            // first, and then the opposite code as well.
+            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
+                Type narrow = op->type.narrow().with_code(code);
+                // Pulling casts out of VectorReduce nodes breaks too much codegen, skip for now.
+                Expr narrow_a = (a.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, a);
+                Expr narrow_b = (b.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, b);
+
+                // This case should have been handled by the above check for widening_add.
+                internal_assert(!(narrow_a.defined() && narrow_b.defined()))
+                    << "find_intrinsics failed to find a widening_add: " << a << " + " << b << "\n";
+
+                if (narrow_a.defined()) {
+                    Expr result;
+                    if (b.type().code() != narrow_a.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = b.type().with_code(code);
+                        result = widen_right_add(reinterpret(t, b), narrow_a);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_add(b, narrow_a);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return result;
+                } else if (narrow_b.defined()) {
+                    Expr result;
+                    if (a.type().code() != narrow_b.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = a.type().with_code(code);
+                        result = widen_right_add(reinterpret(t, a), narrow_b);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_add(a, narrow_b);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return mutate(result);
+                }
+            }
+        }
+
+        // TODO: there can be widen_right_add + widen_right_add simplification rules,
+        // i.e. widen_right_add(a, b) + widen_right_add(c, d) = (a + c) + widening_add(b, d).
+
         if (a.same_as(op->a) && b.same_as(op->b)) {
             return op;
         } else {
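The "safe reinterpret" branches above rely on wrap-around addition being sign-agnostic: reinterpreting between int and uint of the same width before and after the add preserves every bit, which is why a mixed-signedness operand pair can still use the intrinsic. A standalone demonstration of that fact (plain C++, not Halide code):

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Model of reinterpret(Int(32), widen_right_add(reinterpret(UInt(32), a), b)):
    // two's-complement addition is the same bit-level operation for signed and
    // unsigned, so the unsigned detour computes exactly a + int32_t(b).
    int32_t add_via_uint(int32_t a, uint16_t b) {
        uint32_t ua;
        std::memcpy(&ua, &a, sizeof(ua));    // reinterpret(UInt(32), a)
        uint32_t r = ua + b;                 // widen_right_add(ua, b), mod 2^32
        int32_t out;
        std::memcpy(&out, &r, sizeof(out));  // reinterpret(Int(32), ...)
        return out;
    }

    int main() {
        assert(add_via_uint(-5, 3) == -2);
        assert(add_via_uint(-70000, 65535) == -4465);
    }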
@@ -240,6 +314,32 @@ class FindIntrinsics : public IRMutator {
             return Add::make(a, negative_b);
         }
 
+        // Run after the lossless_negate check, because we want that to turn into a widen_right_add if relevant.
+        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
+            // Look for widen_right_sub intrinsics.
+            // Yes, we duplicate code here, but we want to check op->type.code()
+            // first, and then the opposite code as well.
+            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
+                Type narrow = op->type.narrow().with_code(code);
+                Expr narrow_b = lossless_cast(narrow, b);
+
+                if (narrow_b.defined()) {
+                    Expr result;
+                    if (a.type().code() != narrow_b.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = a.type().with_code(code);
+                        result = widen_right_sub(reinterpret(t, a), narrow_b);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_sub(a, narrow_b);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return mutate(result);
+                }
+            }
+        }
+
         if (a.same_as(op->a) && b.same_as(op->b)) {
             return op;
         } else {
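The ordering comment at the top of this hunk is the subtle part: when the narrow right operand can be losslessly negated, the earlier lossless_negate branch has already turned the subtraction into an addition, which the Add visitor then lifts to widen_right_add; this branch only sees what remains. A conceptual sketch, with outcomes hedged since constant folding can intervene:

    // a: Int(32), c: Int(16) variable.
    Expr a = Variable::make(Int(32), "a");
    Expr c = Variable::make(Int(16), "c");

    // a - cast(Int(32), c): c is not provably negatable without overflow
    // (c could be -32768), so this branch lifts it to widen_right_sub(a, c).
    Expr e = a - cast(Int(32), c);

    // Had the narrow operand been losslessly negatable (say, a constant k),
    // lossless_negate would have rewritten a - widen(k) to a + widen(-k)
    // first, and the Add visitor would have produced a widen_right_add.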
@@ -292,6 +392,49 @@ class FindIntrinsics : public IRMutator {
             return mutate(result);
         }
 
+        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
+            // Look for widen_right_mul intrinsics.
+            // Yes, we duplicate code here, but we want to check op->type.code()
+            // first, and then the opposite code as well.
+            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
+                Type narrow = op->type.narrow().with_code(code);
+                Expr narrow_a = lossless_cast(narrow, a);
+                Expr narrow_b = lossless_cast(narrow, b);
+
+                // This case should have been handled by the above check for widening_mul.
+                internal_assert(!(narrow_a.defined() && narrow_b.defined()))
+                    << "find_intrinsics failed to find a widening_mul: " << a << " * " << b << "\n";
+
+                if (narrow_a.defined()) {
+                    Expr result;
+                    if (b.type().code() != narrow_a.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = b.type().with_code(code);
+                        result = widen_right_mul(reinterpret(t, b), narrow_a);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_mul(b, narrow_a);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return result;
+                } else if (narrow_b.defined()) {
+                    Expr result;
+                    if (a.type().code() != narrow_b.type().code()) {
+                        // Need to do a safe reinterpret.
+                        Type t = a.type().with_code(code);
+                        result = widen_right_mul(reinterpret(t, a), narrow_b);
+                        internal_assert(result.type() != op->type);
+                        result = reinterpret(op->type, result);
+                    } else {
+                        result = widen_right_mul(a, narrow_b);
+                    }
+                    internal_assert(result.type() == op->type);
+                    return mutate(result);
+                }
+            }
+        }
+
         if (a.same_as(op->a) && b.same_as(op->b)) {
             return op;
         } else {
@@ -594,6 +737,37 @@ class FindIntrinsics : public IRMutator {
         const auto is_x_wider_opposite_int = (op->type.is_int() && is_uint(x, 2 * bits)) || (op->type.is_uint() && is_int(x, 2 * bits));
 
         if (
+            // Simplify extending patterns.
+            // (x + widen(y)) + widen(z) = x + widening_add(y, z).
+            rewrite(widen_right_add(widen_right_add(x, y), z),
+                    x + widening_add(y, z),
+                    // We only care about integers, this should be trivially true.
+                    is_x_same_int_or_uint) ||
+
+            // (x - widen(y)) - widen(z) = x - widening_add(y, z).
+            rewrite(widen_right_sub(widen_right_sub(x, y), z),
+                    x - widening_add(y, z),
+                    // We only care about integers, this should be trivially true.
+                    is_x_same_int_or_uint) ||
+
+            // (x + widen(y)) - widen(z) = x + cast(t, widening_sub(y, z))
+            // cast (reinterpret) is needed only for uints.
+            rewrite(widen_right_sub(widen_right_add(x, y), z),
+                    x + widening_sub(y, z),
+                    is_x_same_int) ||
+            rewrite(widen_right_sub(widen_right_add(x, y), z),
+                    x + cast(op->type, widening_sub(y, z)),
+                    is_x_same_uint) ||
+
+            // (x - widen(y)) + widen(z) = x + cast(t, widening_sub(z, y))
+            // cast (reinterpret) is needed only for uints.
+            rewrite(widen_right_add(widen_right_sub(x, y), z),
+                    x + widening_sub(z, y),
+                    is_x_same_int) ||
+            rewrite(widen_right_add(widen_right_sub(x, y), z),
+                    x + cast(op->type, widening_sub(z, y)),
+                    is_x_same_uint) ||
+
             // Saturating patterns.
             rewrite(saturating_cast(op->type, widening_add(x, y)),
                     saturating_add(x, y),
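The int/uint split in the last two rule pairs exists because widening_sub always produces a signed result (a difference of uints can be negative), so for unsigned x the result must be cast (a reinterpret) back to x's type before the add; modular arithmetic makes that round trip exact. A scalar spot-check of widen_right_sub(widen_right_add(x, y), z) == x + cast(t, widening_sub(y, z)), with invented uint16/uint8 values:

    #include <cassert>
    #include <cstdint>

    int main() {
        uint16_t x = 100;
        uint8_t y = 3, z = 200;

        // Left side: (x + widen(y)) - widen(z), everything mod 2^16.
        uint16_t lhs = static_cast<uint16_t>(static_cast<uint16_t>(x + y) - z);

        // Right side: widening_sub(y, z) is signed (3 - 200 = -197), then gets
        // reinterpreted to uint16 and added to x, again mod 2^16.
        int16_t wsub = static_cast<int16_t>(y - z);
        uint16_t rhs = static_cast<uint16_t>(x + static_cast<uint16_t>(wsub));

        assert(lhs == rhs);  // both are (100 + 3 - 200) mod 65536 == 65439
    }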
@@ -679,6 +853,7 @@ class FindIntrinsics : public IRMutator {
                 }
             }
         }
+        // TODO: do we want versions of widen_right_add here?
 
         if (op->is_intrinsic(Call::shift_right) || op->is_intrinsic(Call::shift_left)) {
             // Try to turn this into a widening shift.
@@ -885,6 +1060,18 @@
     return expr;
 }
 
+Expr lower_widen_right_add(const Expr &a, const Expr &b) {
+    return a + widen(b);
+}
+
+Expr lower_widen_right_mul(const Expr &a, const Expr &b) {
+    return a * widen(b);
+}
+
+Expr lower_widen_right_sub(const Expr &a, const Expr &b) {
+    return a - widen(b);
+}
+
 Expr lower_widening_add(const Expr &a, const Expr &b) {
     return widen(a) + widen(b);
 }
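These three lowerings are the defining equations of the new ops, and they make lower_intrinsics undo what find_intrinsics lifts. A round-trip sketch (scaffolding invented; find_intrinsics and lower_intrinsics are the entry points this file already exposes):

    // find_intrinsics lifts a mixed-width add into the intrinsic form, and
    // lower_intrinsics expands it back into ordinary arithmetic.
    Expr a = Variable::make(UInt(32), "a");   // wide operand
    Expr b = Variable::make(UInt(16), "b");   // narrow operand
    Expr e = a + cast(UInt(32), b);           // what the frontend typically builds
    Expr lifted = find_intrinsics(e);         // -> widen_right_add(a, b)
    Expr lowered = lower_intrinsics(lifted);  // -> a + widen(b), i.e. e again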
@@ -1100,7 +1287,16 @@ Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q)
 }
 
 Expr lower_intrinsic(const Call *op) {
-    if (op->is_intrinsic(Call::widening_add)) {
+    if (op->is_intrinsic(Call::widen_right_add)) {
+        internal_assert(op->args.size() == 2);
+        return lower_widen_right_add(op->args[0], op->args[1]);
+    } else if (op->is_intrinsic(Call::widen_right_mul)) {
+        internal_assert(op->args.size() == 2);
+        return lower_widen_right_mul(op->args[0], op->args[1]);
+    } else if (op->is_intrinsic(Call::widen_right_sub)) {
+        internal_assert(op->args.size() == 2);
+        return lower_widen_right_sub(op->args[0], op->args[1]);
+    } else if (op->is_intrinsic(Call::widening_add)) {
         internal_assert(op->args.size() == 2);
         return lower_widening_add(op->args[0], op->args[1]);
     } else if (op->is_intrinsic(Call::widening_mul)) {

src/FindIntrinsics.h

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ namespace Halide {
 namespace Internal {
 
 /** Implement intrinsics with non-intrinsic using equivalents. */
+Expr lower_widen_right_add(const Expr &a, const Expr &b);
+Expr lower_widen_right_mul(const Expr &a, const Expr &b);
+Expr lower_widen_right_sub(const Expr &a, const Expr &b);
 Expr lower_widening_add(const Expr &a, const Expr &b);
 Expr lower_widening_mul(const Expr &a, const Expr &b);
 Expr lower_widening_sub(const Expr &a, const Expr &b);
