diff --git a/src/IR.cpp b/src/IR.cpp index bf53aaac476c..bc1581205657 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -669,6 +669,7 @@ const char *const intrinsic_op_names[] = { "shift_right", "signed_integer_overflow", "size_of_halide_buffer_t", + "sliding_window_marker", "sorted_avg", "strict_float", "stringify", diff --git a/src/IR.h b/src/IR.h index ba4fcb09b587..736e43fc9b20 100644 --- a/src/IR.h +++ b/src/IR.h @@ -593,6 +593,15 @@ struct Call : public ExprNode { signed_integer_overflow, size_of_halide_buffer_t, + // Takes a realization name and a loop variable. Declares that values of + // the realization that were stored on earlier loop iterations of the + // given loop are potentially loaded in this loop iteration somewhere + // after this point. Must occur inside a Realize node and For node of + // the given names but outside any corresponding ProducerConsumer + // nodes. Communicates to storage folding that sliding window took + // place. + sliding_window_marker, + // Compute (arg[0] + arg[1]) / 2, assuming arg[0] < arg[1]. sorted_avg, strict_float, diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 8101d66d3fff..a5cd59544cde 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -816,7 +816,9 @@ class SlidingWindow : public IRMutator { } } - SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dimensions[func.name()]); + set &slid_dims = slid_dimensions[func.name()]; + size_t old_slid_dims_size = slid_dims.size(); + SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dims); body = slider.mutate(body); if (func.schedule().memory_type() == MemoryType::Register && @@ -856,6 +858,15 @@ class SlidingWindow : public IRMutator { new_lets.emplace_front(name + ".loop_min.orig", loop_min); new_lets.emplace_front(name + ".loop_extent", (loop_max - loop_min) + 1); } + + if (slid_dims.size() > old_slid_dims_size) { + // Let storage folding know there's now a read-after-write hazard here + Expr marker = Call::make(Int(32), + Call::sliding_window_marker, + {func.name(), Variable::make(Int(32), op->name)}, + Call::Intrinsic); + body = Block::make(Evaluate::make(marker), body); + } } body = mutate(body); diff --git a/src/StorageFolding.cpp b/src/StorageFolding.cpp index d24ab58dfd95..9010e15595b3 100644 --- a/src/StorageFolding.cpp +++ b/src/StorageFolding.cpp @@ -496,6 +496,28 @@ class AttemptStorageFoldingOfFunction : public IRMutator { } } + bool found_sliding_marker = false; + Expr visit(const Call *op) override { + if (op->is_intrinsic(Call::sliding_window_marker)) { + internal_assert(op->args.size() == 2); + const StringImm *name = op->args[0].as(); + internal_assert(name); + if (name->value == func.name()) { + found_sliding_marker = true; + } + } + return op; + } + + Stmt visit(const Block *op) override { + Stmt first = mutate(op->first); + if (found_sliding_marker) { + return Block::make(first, op->rest); + } else { + return Block::make(first, mutate(op->rest)); + } + } + Stmt visit(const For *op) override { if (op->for_type != ForType::Serial && op->for_type != ForType::Unrolled) { // We can't proceed into a parallel for loop. @@ -878,12 +900,10 @@ class AttemptStorageFoldingOfFunction : public IRMutator { } } - // If there's no communication of values from one loop - // iteration to the next (which may happen due to sliding), - // then we're safe to fold an inner loop. - if (box_contains(provided, required)) { - body = mutate(body); - } + // Attempt to fold an inner loop. This will bail out if it encounters a + // ProducerConsumer node for the func, or if it hits a sliding window + // marker. + body = mutate(body); if (body.same_as(op->body)) { stmt = op; @@ -1010,10 +1030,23 @@ class StorageFolding : public IRMutator { } }; +class RemoveSlidingWindowMarkers : public IRMutator { + using IRMutator::visit; + Expr visit(const Call *op) override { + if (op->is_intrinsic(Call::sliding_window_marker)) { + return make_zero(op->type); + } else { + return IRMutator::visit(op); + } + } +}; + } // namespace Stmt storage_folding(const Stmt &s, const std::map &env) { - return StorageFolding(env).mutate(s); + Stmt stmt = StorageFolding(env).mutate(s); + stmt = RemoveSlidingWindowMarkers().mutate(stmt); + return stmt; } } // namespace Internal diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp index 4618da288f23..07f940ed82e3 100644 --- a/test/correctness/fuzz_schedule.cpp +++ b/test/correctness/fuzz_schedule.cpp @@ -161,7 +161,28 @@ int main(int argc, char **argv) { check_blur_output(buf, correct); } - printf("Success!\n"); + // https://github.com/halide/Halide/issues/7909 + { + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + Var yo, yi; + blurry.split(y, yo, yi, 1, TailStrategy::Auto); + local_sum.compute_at(blurry, yo); + local_sum.store_root(); + input.compute_at(local_sum, x); + input.store_root(); + Pipeline p({blurry}); + Buffer buf = p.realize({32, 32}); + check_blur_output(buf, correct); + } + printf("Success!\n"); return 0; }