Skip to content

Commit

Permalink
Fix rfactor adding too many pure loops (halide#8086)
Browse files Browse the repository at this point in the history
When you rfactor an update definition, the new update definition must
use all the pure vars of the Func, even though the one you're rfactoring
may not have used them all.

We also want to preserve any scheduling already done to the pure vars,
so we want to preserve the dims list and splits list from the original
definition.

The code accounted for this by checking the dims list for any missing
pure vars and adding them at the end (just before Var::outermost()), but
this didn't account for the fact that they may no longer exist in the
dims list due to splits that didn't reuse the outer name. In these
circumstances we could end up with too many pure loops. E.g. if x has
been split into xo and xi, then the code was adding a loop for x even
though there were already loops for xo and xi, which of course produces
garbage output.

This PR instead just checks which pure vars are actually used in the
update definition up front, and then uses that to tell which ones should
be added.

Fixes halide#7890
  • Loading branch information
abadams authored and ardier committed Mar 3, 2024
1 parent 3fa535b commit a3b70d2
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
26 changes: 23 additions & 3 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,17 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
vector<Expr> &args = definition.args();
vector<Expr> &values = definition.values();

// Figure out which pure vars were used in this update definition.
std::set<string> pure_vars_used;
internal_assert(args.size() == dim_vars.size());
for (size_t i = 0; i < args.size(); i++) {
if (const Internal::Variable *var = args[i].as<Variable>()) {
if (var->name == dim_vars[i].name()) {
pure_vars_used.insert(var->name);
}
}
}

// Check whether the operator is associative and determine the operator and
// its identity for each value in the definition if it is a Tuple
const auto &prover_result = prove_associativity(func_name, args, values);
Expand Down Expand Up @@ -1012,16 +1023,20 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {

// Determine the dims of the new update definition

// The new update definition needs all the pure vars of the Func, but the
// one we're rfactoring may not have used them all. Add any missing ones to
// the dims list.

// Add pure Vars from the original init definition to the dims list
// if they are not already in the list
for (const Var &v : dim_vars) {
const auto &iter = std::find_if(dims.begin(), dims.end(),
[&v](const Dim &dim) { return var_name_match(dim.var, v.name()); });
if (iter == dims.end()) {
if (!pure_vars_used.count(v.name())) {
Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto};
// Insert it just before Var::outermost
dims.insert(dims.end() - 1, d);
}
}

// Then, we need to remove lifted RVars from the dims list
for (const string &rv : rvars_removed) {
remove(rv);
Expand Down Expand Up @@ -1888,6 +1903,11 @@ Stage &Stage::reorder(const std::vector<VarOrRVar> &vars) {

dims_old.swap(dims);

// We're not allowed to reorder Var::outermost inwards (rfactor assumes it's
// the last one).
user_assert(dims.back().var == Var::outermost().name())
<< "Var::outermost() may not be reordered inside any other var.\n";

return *this;
}

Expand Down
25 changes: 25 additions & 0 deletions test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,31 @@ int main(int argc, char **argv) {
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/7890
{
Func input("input");
Func local_sum("local_sum");
Func blurry("blurry");
Var x("x"), y("y");
RVar yryf;
input(x, y) = 2 * x + 5 * y;
RDom r(-2, 5, -2, 5, "rdom_r");
local_sum(x, y) = 0;
local_sum(x, y) += input(x + r.x, y + r.y);
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);

Var yo, yi, xo, xi, u;
blurry.split(y, yo, yi, 2, TailStrategy::Auto);
local_sum.split(x, xo, xi, 4, TailStrategy::Auto);
local_sum.update(0).split(x, xo, xi, 1, TailStrategy::Auto);
local_sum.update(0).rfactor(r.x, u);
blurry.store_root();
local_sum.compute_root();
Pipeline p({blurry});
auto buf = p.realize({32, 32});
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/8054
{
ImageParam input(Float(32), 2, "input");
Expand Down

0 comments on commit a3b70d2

Please sign in to comment.