Skip to content

Commit

Permalink
Fix two compute_with bugs.
Browse files Browse the repository at this point in the history
This PR fixes a bug in compute_with, and another bug I found while
fixing it (we could really use a compute_with fuzzer).

The first bug is that you can get into situations where the bounds of a
producer func will refer directly to the loop variable of a consumer
func, where the consumer is in a compute_with fused group. In main, that
loop variable may not be defined because fused loop names have been
rewritten to include the token ".fused.". This PR adds let stmts to
define it just inside the fused loop body.

The second bug is that not all parent loops in compute_with fused groups
were having their bounds expanded to cover the region to be computed of
all children, because the logic for deciding which loops to expand only
considered the non-specialized pure definition. So e.g. compute_with
applied to an update stage would fail to compute values of the child
Func where they do not overlap with the parent Func. This PR visits all
definitions of the parent Func of the fused group, instead of just the
unspecialized pure definition of the parent Func.

Fixes #8149
  • Loading branch information
abadams committed Mar 12, 2024
1 parent bf0d611 commit 96a283c
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 75 deletions.
224 changes: 150 additions & 74 deletions src/ScheduleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1021,81 +1021,126 @@ class CollectBounds : public IRVisitor {
}
};

class SubstituteFusedBounds : public IRMutator {
public:
const map<string, Expr> &replacements;
explicit SubstituteFusedBounds(const map<string, Expr> &r)
: replacements(r) {
// Rename a loop var in a compute_with cluster to include '.fused.', to
// disambiguate its bounds from the original loop bounds. The '.fused.' token is
// injected somewhere that's not going to change the results of var_name_match,
// so that it's unchanged as a scheduling point.
string fused_name(const string &var) {
size_t last_dot = var.rfind('.');
internal_assert(last_dot != string::npos);
return var.substr(0, last_dot) + ".fused." + var.substr(last_dot + 1);
}

// The bounds of every loop exist in 'replacements' should be replaced. The
// loop is also renamed by adding '.fused' in the original name before the
// variable name.
Stmt substitute_fused_bounds(Stmt s, const map<string, Expr> &replacements) {
if (!s.defined() || replacements.empty()) {
return s;
}

private:
using IRMutator::visit;
class SubstituteFusedBounds : public IRMutator {
const map<string, Expr> &replacements;

Stmt visit(const For *op) override {
const auto *min_var = op->min.as<Variable>();
const auto *extent_var = op->extent.as<Variable>();
if (min_var && extent_var) {
Expr min_val, extent_val;
{
const auto &it = replacements.find(min_var->name);
if (it != replacements.end()) {
min_val = it->second;
using IRMutator::visit;

Stmt visit(const For *op) override {
const auto *min_var = op->min.as<Variable>();
const auto *extent_var = op->extent.as<Variable>();
if (min_var && extent_var) {
Expr min_val, extent_val;
{
const auto &it = replacements.find(min_var->name);
if (it != replacements.end()) {
min_val = it->second;
}
}
}
{
const auto &it = replacements.find(extent_var->name);
if (it != replacements.end()) {
extent_val = it->second;
{
const auto &it = replacements.find(extent_var->name);
if (it != replacements.end()) {
extent_val = it->second;
}
}
if (!min_val.defined() || !extent_val.defined()) {
return IRMutator::visit(op);
}

Stmt body = mutate(op->body);

string new_var = fused_name(op->name);

ForType for_type = op->for_type;
DeviceAPI device_api = op->device_api;
if (is_const_one(extent_val)) {
// This is the child loop of a fused group. The real loop of the
// fused group is the loop of the parent function of the fused
// group. This child loop is just a scheduling point, and should
// never be a device transition, so we rewrite it to be a simple
// serial loop of extent 1."
for_type = ForType::Serial;
device_api = DeviceAPI::None;
}

Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"),
Variable::make(Int(32), new_var + ".loop_extent"),
for_type, op->partition_policy, device_api, body);

// Add let stmts defining the bound of the renamed for-loop.
stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt);
stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt);
stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt);
// Replace any reference to the old loop name with the new one.
stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt);
return stmt;
} else {
return IRMutator::visit(op);
}
if (!min_val.defined() || !extent_val.defined()) {
}

public:
explicit SubstituteFusedBounds(const map<string, Expr> &r)
: replacements(r) {
}
} subs(replacements);

return subs.mutate(s);
}

// Add letstmts inside each parent loop that define the corresponding child loop
// vars as equal to it. Bounds inference might need a child loop var.
Stmt add_loop_var_aliases(Stmt s, const map<string, set<string>> &loop_var_aliases) {
if (!s.defined() || loop_var_aliases.empty()) {
return s;
}

class AddLoopVarAliases : public IRMutator {
const map<string, set<string>> &loop_var_aliases;

using IRMutator::visit;

Stmt visit(const For *op) override {
auto it = loop_var_aliases.find(op->name);
if (it == loop_var_aliases.end()) {
return IRMutator::visit(op);
}

Expr var = Variable::make(Int(32), op->name);
Stmt body = mutate(op->body);

size_t last_dot = op->name.rfind('.');
internal_assert(last_dot != string::npos);
string new_var = op->name.substr(0, last_dot) + ".fused." + op->name.substr(last_dot + 1);

ForType for_type = op->for_type;
DeviceAPI device_api = op->device_api;
if (is_const_one(extent_val)) {
// This is the child loop of a fused group. The real loop of the
// fused group is the loop of the parent function of the fused
// group. This child loop is just a scheduling point, and should
// never be a device transition, so we rewrite it to be a simple
// serial loop of extent 1."
for_type = ForType::Serial;
device_api = DeviceAPI::None;
for (const string &alias : it->second) {
body = LetStmt::make(alias, var, body);
}

Stmt stmt = For::make(new_var, Variable::make(Int(32), new_var + ".loop_min"),
Variable::make(Int(32), new_var + ".loop_extent"),
for_type, op->partition_policy, device_api, body);
return For::make(op->name, op->min, op->extent, op->for_type,
op->partition_policy, op->device_api, std::move(body));
}

// Add let stmts defining the bound of the renamed for-loop.
stmt = LetStmt::make(new_var + ".loop_min", min_val, stmt);
stmt = LetStmt::make(new_var + ".loop_max", simplify(min_val + extent_val - 1), stmt);
stmt = LetStmt::make(new_var + ".loop_extent", extent_val, stmt);
// Replace any reference to the old loop name with the new one.
stmt = substitute(op->name, Variable::make(Int(32), new_var), stmt);
return stmt;
} else {
return IRMutator::visit(op);
public:
explicit AddLoopVarAliases(const map<string, set<string>> &a)
: loop_var_aliases(a) {
}
}
};
} add_aliases(loop_var_aliases);

// The bounds of every loop exist in 'replacements' should be replaced. The
// loop is also renamed by adding '.fused' in the original name before the
// variable name.
Stmt substitute_fused_bounds(Stmt s, const map<string, Expr> &replacements) {
if (!s.defined() || replacements.empty()) {
return s;
} else {
return SubstituteFusedBounds(replacements).mutate(s);
}
return add_aliases.mutate(s);
}

// Shift the iteration domain of a loop nest by some factor.
Expand Down Expand Up @@ -1460,7 +1505,9 @@ class InjectFunctionRealization : public IRMutator {
}

Stmt build_produce_definition(const Function &f, const string &prefix, const Definition &def, bool is_update,
map<string, Expr> &replacements, vector<pair<string, Expr>> &add_lets) {
map<string, Expr> &replacements,
vector<pair<string, Expr>> &add_lets,
map<string, set<string>> &aliases) {
const vector<Dim> &dims = def.schedule().dims(); // From inner to outer
const LoopLevel &fuse_level = def.schedule().fuse_level().level;

Expand Down Expand Up @@ -1499,6 +1546,10 @@ class InjectFunctionRealization : public IRMutator {
replacements.emplace(var + ".loop_extent", make_const(Int(32), 1));
replacements.emplace(var + ".loop_min", val);
replacements.emplace(var + ".loop_max", val);

string var_fused = fused_name(var_orig);
aliases[var_fused].emplace(std::move(var_orig));
aliases[var_fused].emplace(std::move(var));
}
}

Expand Down Expand Up @@ -1550,18 +1601,17 @@ class InjectFunctionRealization : public IRMutator {

// Replace the bounds of the parent fused loop (i.e. the first one to be
// realized in the group) with union of the bounds of the fused group.
Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce, const map<string, Expr> &bounds) {
string prefix = f.name() + ".s0";
const Definition &def = f.definition();
Stmt replace_parent_bound_with_union_bound(string func, int stage,
const Definition &def, Stmt produce,
const map<string, Expr> &bounds,
map<string, Expr> &replacements) {

if (!def.defined()) {
if (def.schedule().fused_pairs().empty()) {
return produce;
}

const vector<Dim> &dims = def.schedule().dims(); // From inner to outer

map<string, Expr> replacements;

vector<FusedPair> dependence = collect_all_dependence(def);

// Compute the union of the bounds of the fused loops.
Expand All @@ -1582,6 +1632,8 @@ class InjectFunctionRealization : public IRMutator {
// the parent, e.g. y.yi and yi.
int dim2_idx = (int)(dims_2.size() - (dims.size() - i));
internal_assert(dim2_idx < (int)dims_2.size());
string var_1 = func + ".s" + std::to_string(stage) +
"." + dims[i].var;

string var_2 = pair.func_2 + ".s" + std::to_string(pair.stage_2) +
"." + dims_2[dim2_idx].var;
Expand All @@ -1592,7 +1644,6 @@ class InjectFunctionRealization : public IRMutator {
Expr max_2 = bounds.find(var_2 + ".loop_max")->second;
Expr extent_2 = bounds.find(var_2 + ".loop_extent")->second;

string var_1 = prefix + "." + dims[i].var;
internal_assert(bounds.count(var_1 + ".loop_min"));
internal_assert(bounds.count(var_1 + ".loop_max"));
internal_assert(bounds.count(var_1 + ".loop_extent"));
Expand All @@ -1616,8 +1667,26 @@ class InjectFunctionRealization : public IRMutator {
}
}

// Now, replace the bounds of the parent fused loops with the union bounds.
// Now, replace the bounds of the parent fused loops with the union
// bounds.
for (auto &spec : def.specializations()) {
produce = replace_parent_bound_with_union_bound(func, stage, spec.definition, produce, bounds, replacements);
}

return produce;
}

Stmt replace_parent_bound_with_union_bound(const Function &f, Stmt produce,
const map<string, Expr> &bounds) {
map<string, Expr> replacements;

int stage = 0;
produce = replace_parent_bound_with_union_bound(f.name(), stage++, f.definition(), produce, bounds, replacements);
for (const Definition &def : f.updates()) {
produce = replace_parent_bound_with_union_bound(f.name(), stage++, def, produce, bounds, replacements);
}
produce = substitute_fused_bounds(produce, replacements);

return produce;
}

Expand Down Expand Up @@ -1748,22 +1817,23 @@ class InjectFunctionRealization : public IRMutator {
Stmt producer;
map<string, Expr> replacements;
vector<pair<string, Expr>> add_lets;
map<string, set<string>> aliases;

for (const auto &func_stage : stage_order) {
const auto &f = func_stage.first;

if (f.has_extern_definition() && (func_stage.second == 0)) {
const Stmt &produceDef = Internal::build_extern_produce(env, f, target);
producer = inject_stmt(producer, produceDef, LoopLevel::inlined().lock());
const Stmt &produce_def = Internal::build_extern_produce(env, f, target);
producer = inject_stmt(producer, produce_def, LoopLevel::inlined().lock());
continue;
}

string def_prefix = f.name() + ".s" + std::to_string(func_stage.second) + ".";
const auto &def = (func_stage.second == 0) ? f.definition() : f.updates()[func_stage.second - 1];

const Stmt &produceDef = build_produce_definition(f, def_prefix, def, func_stage.second > 0,
replacements, add_lets);
producer = inject_stmt(producer, produceDef, def.schedule().fuse_level().level);
const Stmt &produce_def = build_produce_definition(f, def_prefix, def, func_stage.second > 0,
replacements, add_lets, aliases);
producer = inject_stmt(producer, produce_def, def.schedule().fuse_level().level);
}

internal_assert(producer.defined());
Expand Down Expand Up @@ -1799,8 +1869,14 @@ class InjectFunctionRealization : public IRMutator {

// Replace the bounds of parent fused loop with union of bounds of
// the fused loops.
Function group_parent = funcs.back();
producer = replace_parent_bound_with_union_bound(funcs.back(), producer, bounds);

// Define the old loop var names as equal to the corresponding parent
// fused loop var. Bounds inference might refer directly to the original
// loop vars.
producer = add_loop_var_aliases(producer, aliases);

// Add the producer nodes.
for (const auto &i : funcs) {
producer = ProducerConsumer::make_produce(i.name(), producer);
Expand Down
Loading

0 comments on commit 96a283c

Please sign in to comment.