diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 09b4aed1036d..fd36c6e4fbd2 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -558,10 +558,10 @@ Stmt Simplify::visit(const Block *op) { equal(if_first->condition, if_next->condition) && is_pure(if_first->condition)) { // Two ifs with matching conditions. - Stmt then_case = mutate(Block::make(if_first->then_case, if_next->then_case)); + Stmt then_case = Block::make(if_first->then_case, if_next->then_case); Stmt else_case; if (if_first->else_case.defined() && if_next->else_case.defined()) { - else_case = mutate(Block::make(if_first->else_case, if_next->else_case)); + else_case = Block::make(if_first->else_case, if_next->else_case); } else if (if_first->else_case.defined()) { // We already simplified the body of the ifs. else_case = if_first->else_case; @@ -572,7 +572,9 @@ Stmt Simplify::visit(const Block *op) { if (if_rest.defined()) { result = Block::make(result, if_rest); } - return result; + // We must mutate the entire IfThenElse block without first mutating the + // branches to compute reachability accurately. + return mutate(result); } else if (if_first && if_next && !if_next->else_case.defined() && @@ -583,13 +585,15 @@ Stmt Simplify::visit(const Block *op) { // the first condition. The second if can be nested // inside the first one, because if it's true the // first one must also be true. - Stmt then_case = mutate(Block::make(if_first->then_case, if_next)); + Stmt then_case = Block::make(if_first->then_case, if_next); Stmt else_case = if_first->else_case; Stmt result = IfThenElse::make(if_first->condition, then_case, else_case); if (if_rest.defined()) { result = Block::make(result, if_rest); } - return result; + // As above, we must mutate the entire IfThenElse block without first + // mutating the branches to compute reachability accurately. + return mutate(result); } else if (if_first && if_next && is_pure(if_first->condition) && @@ -608,7 +612,7 @@ Stmt Simplify::visit(const Block *op) { if (if_rest.defined()) { result = Block::make(result, if_rest); } - return result; + return mutate(result); } else if (op->first.same_as(first) && op->rest.same_as(rest)) { return op; diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp index d5a2a664fec5..adc90ac74a0e 100644 --- a/test/correctness/fuzz_schedule.cpp +++ b/test/correctness/fuzz_schedule.cpp @@ -96,6 +96,28 @@ int main(int argc, char **argv) { check_blur_output(buf, correct); } + // https://github.com/halide/Halide/issues/7892 + { + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + RVar yryf; + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5, "rdom_r"); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + Var xo, xi, xoo, xoi, yo, yi; + local_sum.vectorize(x) + .split(x, xo, xi, 2, TailStrategy::PredicateStores) + .split(xo, xoo, xoi, 4, TailStrategy::RoundUp) + .unroll(xoi); + local_sum.update(0).unscheduled(); + Pipeline p({blurry}); + Buffer buf = p.realize({32, 32}); + } + printf("Success!\n"); return 0;