Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stronger chain detection in LoopCarry pass #8016

Merged
merged 10 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/LoopCarry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,34 @@ class LoopCarryOverLoop : public IRMutator {

// For each load, move the load index forwards by one loop iteration
vector<Expr> indices, next_indices, predicates, next_predicates;
// CSE-d versions of the above, so can_prove can be safely used on them.
vector<Expr> indices_csed, next_indices_csed, predicates_csed, next_predicates_csed;
for (const vector<const Load *> &v : loads) {
indices.push_back(v[0]->index);
next_indices.push_back(step_forwards(v[0]->index, linear));
predicates.push_back(v[0]->predicate);
next_predicates.push_back(step_forwards(v[0]->predicate, linear));

if (indices.back().defined()) {
indices_csed.push_back(common_subexpression_elimination(indices.back()));
} else {
indices_csed.emplace_back();
}
if (next_indices.back().defined()) {
next_indices_csed.push_back(common_subexpression_elimination(next_indices.back()));
} else {
next_indices_csed.emplace_back();
}
if (predicates.back().defined()) {
predicates_csed.push_back(common_subexpression_elimination(predicates.back()));
} else {
predicates_csed.emplace_back();
}
if (next_predicates.back().defined()) {
next_predicates_csed.push_back(common_subexpression_elimination(next_predicates.back()));
} else {
next_predicates_csed.emplace_back();
}
}

// Find loads done on this loop iteration that will be
Expand All @@ -299,11 +322,16 @@ class LoopCarryOverLoop : public IRMutator {
if (i == j) {
continue;
}
// can_prove is stronger than graph_equal, because it doesn't require index expressions to be
// exactly the same, but evaluate to the same value. We keep the graph_equal check, because
// it's faster and should be executed before the more expensive check.
if (loads[i][0]->name == loads[j][0]->name &&
next_indices[j].defined() &&
graph_equal(indices[i], next_indices[j]) &&
(graph_equal(indices[i], next_indices[j]) ||
((indices[i].type() == next_indices[j].type()) && can_prove(indices_csed[i] == next_indices_csed[j]))) &&
next_predicates[j].defined() &&
graph_equal(predicates[i], next_predicates[j])) {
(graph_equal(predicates[i], next_predicates[j]) ||
((predicates[i].type() == next_predicates[j].type()) && can_prove(predicates_csed[i] == next_predicates_csed[j])))) {
chains.push_back({j, i});
debug(3) << "Found carried value:\n"
<< i << ": -> " << Expr(loads[i][0]) << "\n"
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ tests(GROUPS correctness
likely.cpp
load_library.cpp
logical.cpp
loop_carry.cpp
loop_invariant_extern_calls.cpp
loop_level_generator_param.cpp
lossless_cast.cpp
Expand Down
69 changes: 69 additions & 0 deletions test/correctness/loop_carry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include "Halide.h"
#include <stdio.h>

// This file demonstrates two example custom lowering passes. The
vksnk marked this conversation as resolved.
Show resolved Hide resolved
// first just makes sure the IR passes some test, and doesn't modify
// it. The second actually changes the IR in some useful way.

using namespace Halide;
using namespace Halide::Internal;

// Verify that all floating point divisions by constants have been
// converted to float multiplication.
class LoopCarryWrapper : public IRMutator {
using IRMutator::visit;

int register_count_;
Stmt mutate(const Stmt &stmt) override {
return simplify(loop_carry(stmt, register_count_));
}

public:
LoopCarryWrapper(int register_count)
: register_count_(register_count) {
}
};

int main(int argc, char **argv) {
Func input;
Func g;
Func h;
Func f;
Var x, y, xo, yo, xi, yi;

input(x, y) = x + y;

Expr sum_expr = 0;
for (int ix = -100; ix <= 100; ix++) {
// Generate two chains of sums, but only one of them will be carried.
sum_expr += input(x, y + ix);
sum_expr += input(x + 13, y + 2 * ix);
}
g(x, y) = sum_expr;
h(x, y) = g(x, y) + 12;
f(x, y) = h(x, y);

// Make a maximum number of the carried values very large for the purpose
// of this test.
constexpr int kMaxRegisterCount = 1024;
f.add_custom_lowering_pass(new LoopCarryWrapper(kMaxRegisterCount));

const int size = 128;
f.compute_root()
.bound(x, 0, size)
.bound(y, 0, size);

h.compute_root()
.tile(x, y, xo, yo, xi, yi, 16, 16, TailStrategy::RoundUp);

g.compute_at(h, xo)
.reorder(y, x)
.vectorize(x, 4);

input.compute_root();

f.realize({size, size});

printf("Success!\n");
return 0;
}