From 956e84714ac1facc1909a2a3f01944416b72564e Mon Sep 17 00:00:00 2001 From: Lars Date: Wed, 1 May 2024 11:51:52 +0200 Subject: [PATCH] Fix two compute_with bugs. (#8152) * Fix two compute_with bugs. This PR fixes a bug in compute_with, and another bug I found while fixing it (we could really use a compute_with fuzzer). The first bug is that you can get into situations where the bounds of a producer func will refer directly to the loop variable of a consumer func, where the consumer is in a compute_with fused group. In main, that loop variable may not be defined because fused loop names have been rewritten to include the token ".fused.". This PR adds let stmts to define it just inside the fused loop body. The second bug is that not all parent loops in compute_with fused groups were having their bounds expanded to cover the region to be computed of all children, because the logic for deciding which loops to expand only considered the non-specialized pure definition. So e.g. compute_with applied to an update stage would fail to compute values of the child Func where they do not overlap with the parent Func. This PR visits all definitions of the parent Func of the fused group, instead of just the unspecialized pure definition of the parent Func. Fixes #8149 * clang-tidy --- src/ScheduleFunctions.cpp | 228 ++++++++++++++++++++---------- test/correctness/compute_with.cpp | 140 ++++++++++++++++++ 2 files changed, 292 insertions(+), 76 deletions(-) diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index d0c0f9f7534c..60dc1d566a0c 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -1352,81 +1352,126 @@ class CollectBounds : public IRVisitor { } }; -class SubstituteFusedBounds : public IRMutator { -public: - const map &replacements; - explicit SubstituteFusedBounds(const map &r) - : replacements(r) { +// Rename a loop var in a compute_with cluster to include '.fused.', to +// disambiguate its bounds from the original loop bounds. 
The '.fused.' token is +// injected somewhere that's not going to change the results of var_name_match, +// so that it's unchanged as a scheduling point. +string fused_name(const string &var) { + size_t last_dot = var.rfind('.'); + internal_assert(last_dot != string::npos); + return var.substr(0, last_dot) + ".fused." + var.substr(last_dot + 1); +} + +// The bounds of every loop exist in 'replacements' should be replaced. The +// loop is also renamed by adding '.fused' in the original name before the +// variable name. +Stmt substitute_fused_bounds(Stmt s, const map &replacements) { + if (!s.defined() || replacements.empty()) { + return s; } -private: - using IRMutator::visit; + class SubstituteFusedBounds : public IRMutator { + const map &replacements; - Stmt visit(const For *op) override { - const auto *min_var = op->min.as(); - const auto *extent_var = op->extent.as(); - if (min_var && extent_var) { - Expr min_val, extent_val; - { - const auto &it = replacements.find(min_var->name); - if (it != replacements.end()) { - min_val = it->second; + using IRMutator::visit; + + Stmt visit(const For *op) override { + const auto *min_var = op->min.as(); + const auto *extent_var = op->extent.as(); + if (min_var && extent_var) { + Expr min_val, extent_val; + { + const auto &it = replacements.find(min_var->name); + if (it != replacements.end()) { + min_val = it->second; + } } - } - { - const auto &it = replacements.find(extent_var->name); - if (it != replacements.end()) { - extent_val = it->second; + { + const auto &it = replacements.find(extent_var->name); + if (it != replacements.end()) { + extent_val = it->second; + } + } + if (!min_val.defined() || !extent_val.defined()) { + return IRMutator::visit(op); + } + + Stmt body = mutate(op->body); + + string new_var = fused_name(op->name); + + ForType for_type = op->for_type; + DeviceAPI device_api = op->device_api; + if (is_const_one(extent_val)) { + // This is the child loop of a fused group. 
The real loop of the + // fused group is the loop of the parent function of the fused + // group. This child loop is just a scheduling point, and should + // never be a device transition, so we rewrite it to be a simple + // serial loop of extent 1." + for_type = ForType::Serial; + device_api = DeviceAPI::None; } + + Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), + Variable::make(Int(32), new_var + ".loop_extent"), + for_type, device_api, body, op->annotations); + + // Add let stmts defining the bound of the renamed for-loop. + stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); + stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt); + stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt); + // Replace any reference to the old loop name with the new one. + stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); + return stmt; + } else { + return IRMutator::visit(op); } - if (!min_val.defined() || !extent_val.defined()) { + } + + public: + explicit SubstituteFusedBounds(const map &r) + : replacements(r) { + } + } subs(replacements); + + return subs.mutate(s); +} + +// Add letstmts inside each parent loop that define the corresponding child loop +// vars as equal to it. Bounds inference might need a child loop var. +Stmt add_loop_var_aliases(Stmt s, const map> &loop_var_aliases) { + if (!s.defined() || loop_var_aliases.empty()) { + return s; + } + + class AddLoopVarAliases : public IRMutator { + const map> &loop_var_aliases; + + using IRMutator::visit; + + Stmt visit(const For *op) override { + auto it = loop_var_aliases.find(op->name); + if (it == loop_var_aliases.end()) { return IRMutator::visit(op); } + Expr var = Variable::make(Int(32), op->name); Stmt body = mutate(op->body); - - size_t last_dot = op->name.rfind('.'); - internal_assert(last_dot != string::npos); - string new_var = op->name.substr(0, last_dot) + ".fused." 
+ op->name.substr(last_dot + 1); - - ForType for_type = op->for_type; - DeviceAPI device_api = op->device_api; - if (is_const_one(extent_val)) { - // This is the child loop of a fused group. The real loop of the - // fused group is the loop of the parent function of the fused - // group. This child loop is just a scheduling point, and should - // never be a device transition, so we rewrite it to be a simple - // serial loop of extent 1." - for_type = ForType::Serial; - device_api = DeviceAPI::None; + for (const string &alias : it->second) { + body = LetStmt::make(alias, var, body); } - Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"), - Variable::make(Int(32), new_var + ".loop_extent"), - for_type, device_api, body, op->annotations); + return For::make(op->name, op->min, op->extent, op->for_type, + op->device_api, std::move(body), op->annotations); + } - // Add let stmts defining the bound of the renamed for-loop. - stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt); - stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt); - stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt); - // Replace any reference to the old loop name with the new one. - stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt); - return stmt; - } else { - return IRMutator::visit(op); + public: + explicit AddLoopVarAliases(const map> &a) + : loop_var_aliases(a) { } - } -}; + } add_aliases(loop_var_aliases); -// The bounds of every loop exist in 'replacements' should be replaced. The -// loop is also renamed by adding '.fused' in the original name before the -// variable name. -Stmt substitute_fused_bounds(Stmt s, const map &replacements) { - if (!s.defined() || replacements.empty()) { - return s; - } else { - return SubstituteFusedBounds(replacements).mutate(s); - } + return add_aliases.mutate(s); } // Shift the iteration domain of a loop nest by some factor. 
@@ -1994,7 +2039,9 @@ class InjectFunctionRealization : public IRMutator { } Stmt build_produce_definition(const Function &f, const string &prefix, const Definition &def, bool is_update, - map &replacements, vector> &add_lets) { + map &replacements, + vector> &add_lets, + map> &aliases) { const vector &dims = def.schedule().dims(); // From inner to outer const LoopLevel &fuse_level = def.schedule().fuse_level().level; @@ -2033,6 +2080,10 @@ class InjectFunctionRealization : public IRMutator { replacements.emplace(var + ".loop_extent", make_const(Int(32), 1)); replacements.emplace(var + ".loop_min", val); replacements.emplace(var + ".loop_max", val); + + string var_fused = fused_name(var_orig); + aliases[var_fused].emplace(std::move(var_orig)); + aliases[var_fused].emplace(std::move(var)); } } @@ -2086,18 +2137,17 @@ class InjectFunctionRealization : public IRMutator { // Replace the bounds of the parent fused loop (i.e. the first one to be // realized in the group) with union of the bounds of the fused group. - Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, const map &bounds) { - string prefix = f.name() + ".s0"; - const Definition &def = f.definition(); + Stmt replace_parent_bound_with_union_bound(const string &func, int stage, + const Definition &def, Stmt produce, + const map &bounds, + map &replacements) { - if (!def.defined()) { + if (def.schedule().fused_pairs().empty()) { return produce; } const vector &dims = def.schedule().dims(); // From inner to outer - map replacements; - vector dependence = collect_all_dependence(def); // Compute the union of the bounds of the fused loops. @@ -2118,6 +2168,8 @@ class InjectFunctionRealization : public IRMutator { // the parent, e.g. y.yi and yi. int dim2_idx = (int)(dims_2.size() - (dims.size() - i)); internal_assert(dim2_idx < (int)dims_2.size()); + string var_1 = func + ".s" + std::to_string(stage) + + "." 
+ dims[i].var; string var_2 = pair.func_2 + ".s" + std::to_string(pair.stage_2) + "." + dims_2[dim2_idx].var; @@ -2128,7 +2180,6 @@ class InjectFunctionRealization : public IRMutator { Expr max_2 = bounds.find(var_2 + ".loop_max")->second; Expr extent_2 = bounds.find(var_2 + ".loop_extent")->second; - string var_1 = prefix + "." + dims[i].var; internal_assert(bounds.count(var_1 + ".loop_min")); internal_assert(bounds.count(var_1 + ".loop_max")); internal_assert(bounds.count(var_1 + ".loop_extent")); @@ -2152,8 +2203,26 @@ class InjectFunctionRealization : public IRMutator { } } - // Now, replace the bounds of the parent fused loops with the union bounds. + // Now, replace the bounds of the parent fused loops with the union + // bounds. + for (const auto &spec : def.specializations()) { + produce = replace_parent_bound_with_union_bound(func, stage, spec.definition, produce, bounds, replacements); + } + + return produce; + } + + Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, + const map &bounds) { + map replacements; + + int stage = 0; + produce = replace_parent_bound_with_union_bound(f.name(), stage++, f.definition(), produce, bounds, replacements); + for (const Definition &def : f.updates()) { + produce = replace_parent_bound_with_union_bound(f.name(), stage++, def, produce, bounds, replacements); + } produce = substitute_fused_bounds(produce, replacements); + return produce; } @@ -2285,22 +2354,23 @@ class InjectFunctionRealization : public IRMutator { Stmt producer; map replacements; vector> add_lets; + map> aliases; for (const auto &func_stage : stage_order) { const auto &f = func_stage.first; if (f.has_extern_definition() && (func_stage.second == 0)) { - const Stmt &produceDef = Internal::build_extern_produce(env, f, target); - producer = inject_stmt(producer, produceDef, LoopLevel::inlined().lock()); + const Stmt &produce_def = Internal::build_extern_produce(env, f, target); + producer = inject_stmt(producer, produce_def, 
LoopLevel::inlined().lock()); continue; } string def_prefix = f.name() + ".s" + std::to_string(func_stage.second) + "."; const auto &def = (func_stage.second == 0) ? f.definition() : f.updates()[func_stage.second - 1]; - const Stmt &produceDef = build_produce_definition(f, def_prefix, def, func_stage.second > 0, - replacements, add_lets); - producer = inject_stmt(producer, produceDef, def.schedule().fuse_level().level); + const Stmt &produce_def = build_produce_definition(f, def_prefix, def, func_stage.second > 0, + replacements, add_lets, aliases); + producer = inject_stmt(producer, produce_def, def.schedule().fuse_level().level); } internal_assert(producer.defined()); @@ -2328,8 +2398,8 @@ class InjectFunctionRealization : public IRMutator { } // TODO Lars vd Haak: Don't know what to do with shifts yet internal_assert(shifts.empty()); - // Todo neither with replacements - internal_assert(replacements.empty()); + // TODO neither with replacements + // internal_assert(replacements.empty()); // Shift the loops. producer = ShiftLoopNest::apply_shift(shifts, producer); @@ -2341,8 +2411,14 @@ class InjectFunctionRealization : public IRMutator { // Replace the bounds of parent fused loop with union of bounds of // the fused loops. + Function group_parent = funcs.back(); producer = replace_parent_bound_with_union_bound(funcs.back(), producer, bounds); + // Define the old loop var names as equal to the corresponding parent + // fused loop var. Bounds inference might refer directly to the original + // loop vars. + producer = add_loop_var_aliases(producer, aliases); + // Add the producer nodes. 
for (const auto &i : funcs) { producer = ProducerConsumer::make_produce(i.name(), producer);
diff --git a/test/correctness/compute_with.cpp b/test/correctness/compute_with.cpp
index 02d980ab3001..3176572b0442 100644
--- a/test/correctness/compute_with.cpp
+++ b/test/correctness/compute_with.cpp
@@ -2160,6 +2160,153 @@ int main(int argc, char **argv) {
     printf("Running store_at different levels test\n");
     if (store_at_different_levels_test() != 0) {
         return -1;
+// Test for the issue described in https://github.com/halide/Halide/issues/8149.
+int child_var_dependent_bounds_test() {
+    Func f{"f"}, g{"g"};
+    Var x{"x"}, y{"y"};
+    RDom r(0, 10, "r");
+
+    Func f_inter{"f_inter"}, g_inter{"g_inter"};
+
+    f_inter(x, y) = x;
+    f_inter(x, y) += 1;
+    f(x) = x;
+    f(x) += f_inter(x, r);
+
+    g_inter(x, y) = x;
+    g_inter(x, y) += 1;
+    g(x) = x;
+    g(x) += g_inter(x, r);
+
+    f_inter.compute_at(f, r);
+    g_inter.compute_at(f, r);
+    g.update().compute_with(f.update(), r);
+    f.update().unscheduled();
+
+    Pipeline p({f, g});
+
+    p.compile_jit();
+    Buffer f_buf(10), g_buf(10);
+
+    // Realize f over [2, 12) but g over [0, 10). After the min is reset to 0,
+    // f_buf(i) holds the value computed at x = i + 2.
+    f_buf.set_min(2);
+    p.realize({f_buf, g_buf});
+    f_buf.set_min(0);
+
+    for (int i = 0; i < 10; i++) {
+        int correct_f = 10 + 11 * (i + 2);
+        int correct_g = 10 + 11 * i;
+        if (f_buf(i) != correct_f) {
+            printf("f(%d) = %d instead of %d\n", i, f_buf(i), correct_f);
+            return 1;
+        }
+        if (g_buf(i) != correct_g) {
+            printf("g(%d) = %d instead of %d\n", i, g_buf(i), correct_g);
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
+// Test compute_with between update stages whose outputs are realized over
+// different, only partially overlapping regions.
+int overlapping_updates_test() {
+    Func f{"f"}, g{"g"};
+    Var x{"x"};
+
+    f(x) = 0;
+    f(x) += x;
+    g(x) = 0;
+    g(x) += x;
+
+    g.update().compute_with(f.update(), x);
+    f.update().unscheduled();
+
+    Pipeline p({f, g});
+
+    p.compile_jit();
+    Buffer f_buf(10), g_buf(10);
+
+    // f is realized over [2, 12), g over [0, 10); f_buf(i) is f at x = i + 2.
+    f_buf.set_min(2);
+    p.realize({f_buf, g_buf});
+    f_buf.set_min(0);
+
+    for (int i = 0; i < 10; i++) {
+        int correct_f = i + 2;
+        int correct_g = i;
+        if (f_buf(i) != correct_f) {
+            printf("f(%d) = %d instead of %d\n", i, f_buf(i), correct_f);
+            return 1;
+        }
+        if (g_buf(i) != correct_g) {
+            printf("g(%d) = %d instead of %d\n", i, g_buf(i), correct_g);
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
+} // namespace
+
+int main(int argc, char **argv) {
+    struct Task {
+        std::string desc;
+        std::function fn;
+    };
+
+    std::vector tasks = {
+        {"split reorder test", split_test},
+        {"fuse test", fuse_test},
+        {"multiple fuse group test", multiple_fuse_group_test},
+        {"multiple outputs test", multiple_outputs_test},
+        {"double split fuse test", double_split_fuse_test},
+        {"vectorize test", vectorize_test},
+        //
+        // Note: we are deprecating skipping parts of a fused group in favor of
+        // cloning funcs in particular stages via a new (clone_)in overload.
+        // TODO: remove this code when the new clone_in is implemented.
+        //
+        // {"some are skipped test", some_are_skipped_test},
+        {"rgb to yuv420 test", rgb_yuv420_test},
+        {"with specialization test", with_specialization_test},
+        {"fuse compute at test", fuse_compute_at_test},
+        {"nested compute with test", nested_compute_with_test},
+        {"mixed tile factor test", mixed_tile_factor_test},
+        // NOTE: disabled because it generates OOB (see #4751 for discussion).
+        // {"only some are tiled test", only_some_are_tiled_test},
+        {"multiple outputs on gpu test", multiple_outputs_on_gpu_test},
+        {"multi tile mixed tile factor test", multi_tile_mixed_tile_factor_test},
+        {"update stage test", update_stage_test},
+        {"update stage2 test", update_stage2_test},
+        {"update stage3 test", update_stage3_test},
+        {"update stage pairwise test", update_stage_pairwise_test},
+        // I think this should work, but there is an overzealous check somewhere.
+ // {"update stage pairwise zigzag test", update_stage_pairwise_zigzag_test}, + {"update stage diagonal test", update_stage_diagonal_test}, + {"update stage rfactor test", update_stage_rfactor_test}, + {"vectorize inlined test", vectorize_inlined_test}, + {"mismatching splits test", mismatching_splits_test}, + {"different arg number compute_at test", different_arg_num_compute_at_test}, + {"store_at different levels test", store_at_different_levels_test}, + {"rvar bounds test", rvar_bounds_test}, + {"two compute at test", two_compute_at_test}, + {"overlapping updates test", overlapping_updates_test}, + {"child var dependent bounds test", child_var_dependent_bounds_test}, + }; + + using Sharder = Halide::Internal::Test::Sharder; + Sharder sharder; + for (size_t t = 0; t < tasks.size(); t++) { + if (!sharder.should_run(t)) continue; + const auto &task = tasks.at(t); + std::cout << task.desc << "\n"; + if (task.fn() != 0) { + return 1; + } } printf("Success!\n");