Skip to content

Commit

Permalink
Assignment is not associative (#7894)
Browse files Browse the repository at this point in the history
* Assignment is not associative

* Fix internal tests
  • Loading branch information
abadams authored Oct 17, 2023
1 parent f9b90cb commit 5c97c3c
Showing 1 changed file with 18 additions and 39 deletions.
57 changes: 18 additions & 39 deletions src/Associativity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,12 @@ bool extract_associative_op(const vector<Expr> &exprs, const vector<string> &op_
if (exprs.size() == 1) {
Type t = exprs[0].type();
if (!x_parts[0].defined()) {
// Update with no self-recurrence is associative and the identity
// can be anything since it's going to be replaced anyway, but it's
// not commutative
assoc_op.pattern.ops[0] = Variable::make(t, op_y_names[0]);
assoc_op.pattern.identities[0] = make_const(t, 0);
assoc_op.pattern.is_commutative = false;
assoc_op.xs[0] = {"", Expr()};
assoc_op.ys[0] = {op_y_names[0], exprs[0]};
return true;
// An update that just assigns some value is not associative,
// because there's no good identity. An identity is necessary
// because things like rfactor will combine the identity with
// partially-computed values and expect it to do nothing. For an
// example, see https://github.com/halide/Halide/issues/7893
return false;
} else if (equal(exprs[0], Variable::make(t, op_x_names[0]))) {
// Self assignment, f(x) = f(x), is both associative
// and commutative. The identity can be anything since it's
Expand Down Expand Up @@ -657,14 +654,6 @@ void associativity_test() {
{Replacement("y", max(g_call_0, -3))},
true));

// f(x) = min(4, g(rx)) -> trivially associative
check_associativity("f", {x}, {min(4, g_call_0)},
AssociativeOp(
AssociativePattern(y, make_const(t, 0), true),
{Replacement("", Expr())},
{Replacement("y", min(g_call_0, 4))},
true));

// f(x) = max(max(min(f(x), g(rx) + 2), f(x)), g(rx) + 2) -> can be simplified into max(f(x), g(rx) + 2)
check_associativity("f", {x}, {max(max(min(f_call_0, g_call_0 + 2), f_call_0), g_call_0 + 2)},
AssociativeOp(
Expand Down Expand Up @@ -705,24 +694,14 @@ void associativity_test() {
Expr g_call_0 = Call::make(ts[0], "g", {rx}, Call::CallType::Halide, FunctionPtr(), 0);
Expr g_call_1 = Call::make(ts[1], "g", {rx}, Call::CallType::Halide, FunctionPtr(), 1);

// f(x) = Tuple(f(x)[0], 3, f(x)[2] + z)
check_associativity("f", {x}, {f_call_0, make_const(ts[1], 3), f_call_2 + cast(ts[2], z)},
// f(x) = Tuple(f(x)[0], f(x)[2] + z)
check_associativity("f", {x}, {f_call_0, f_call_1 + cast(ts[1], z)},
AssociativeOp(
AssociativePattern({xs[0], ys[1], xs[2] + ys[2]},
{make_const(ts[0], 0), make_const(ts[1], 0), make_const(ts[2], 0)},
AssociativePattern({xs[0], xs[1] + ys[1]},
{make_const(ts[0], 0), make_const(ts[1], 0)},
true),
{Replacement("x0", f_call_0), Replacement("", Expr()), Replacement("x2", f_call_2)},
{Replacement("", Expr()), Replacement("y1", make_const(ts[1], 3)), Replacement("y2", cast(ts[2], z))},
true));

// f(x) = Tuple(2, 3, f(x)[2] + z)
check_associativity("f", {x}, {make_const(ts[0], 2), make_const(ts[1], 3), f_call_2 + cast(ts[2], z)},
AssociativeOp(
AssociativePattern({ys[0], ys[1], xs[2] + ys[2]},
{make_const(ts[0], 0), make_const(ts[1], 0), make_const(ts[2], 0)},
true),
{Replacement("", Expr()), Replacement("", Expr()), Replacement("x2", f_call_2)},
{Replacement("y0", make_const(ts[0], 2)), Replacement("y1", make_const(ts[1], 3)), Replacement("y2", cast(ts[2], z))},
{Replacement("x0", f_call_0), Replacement("x1", f_call_1)},
{Replacement("", Expr()), Replacement("y1", cast(ts[1], z))},
true));

// f(x) = Tuple(min(f(x)[0], g(rx)), f(x)[1]*g(x)*2, f(x)[2] + z)
Expand Down Expand Up @@ -780,24 +759,24 @@ void associativity_test() {
Expr f_xy_call_3 = Call::make(ts[3], "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 3);
Expr g_xy_call_0 = Call::make(ts[0], "g", {rx, ry}, Call::CallType::Halide, FunctionPtr(), 0);

// 2D argmin + trivial update (with mixed types):
// 2D argmin + sum
// f(x, y) = Tuple(min(f(x, y)[0], g(r.x, r.y)[0]),
// r.x + r.y,
// f(x, y)[1] + r.x,
// select(f(x, y)[0] < g(r.x, r.y)[0], f(x)[2], r.x),
// select(f(x, y)[0] < g(r.x, r.y)[0], f(x)[3], r.y))
check_associativity("f", {x, y},
{min(f_xy_call_0, g_xy_call_0),
rx + ry,
f_xy_call_1 + rx,
select(f_xy_call_0 < g_xy_call_0, f_xy_call_2, cast(Int(16), rx)),
select(f_xy_call_0 < g_xy_call_0, f_xy_call_3, cast(Float(32), ry))},
AssociativeOp(
AssociativePattern(
{min(xs[0], ys[0]), ys[1], select(xs[0] < ys[0], xs[2], ys[2]), select(xs[0] < ys[0], xs[3], ys[3])},
{min(xs[0], ys[0]), xs[1] + ys[1], select(xs[0] < ys[0], xs[2], ys[2]), select(xs[0] < ys[0], xs[3], ys[3])},
{ts[0].max(), make_const(ts[1], 0), make_const(ts[2], 0), make_const(ts[3], 0)},
true),
{Replacement("x0", f_xy_call_0), Replacement("", Expr()),
{Replacement("x0", f_xy_call_0), Replacement("x1", f_xy_call_1),
Replacement("x2", f_xy_call_2), Replacement("x3", f_xy_call_3)},
{Replacement("y0", g_xy_call_0), Replacement("y1", rx + ry),
{Replacement("y0", g_xy_call_0), Replacement("y1", rx),
Replacement("y2", cast(Int(16), rx)), Replacement("y3", cast(Float(32), ry))},
true));
}
Expand Down

0 comments on commit 5c97c3c

Please sign in to comment.