Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hoist vector slices using rewrite rules #7243

Merged
merged 10 commits into from
Jan 21, 2023
77 changes: 70 additions & 7 deletions src/IRMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -2080,6 +2080,69 @@ HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<declty
return {t, pattern_arg(a)};
}

// Pattern node for the rewrite-rule matcher that represents a strided slice
// of a single vector: slice(vec, base, stride, lanes).
//
// Template parameters are the sub-pattern types:
//   Vec    - pattern for the vector being sliced
//   Base   - pattern for the index of the first selected lane
//   Stride - pattern for the step between selected lanes
//   Lanes  - pattern for the number of lanes in the result
// Base, Stride and Lanes must be constant-foldable (enforced by the
// static_asserts in the constructor), because make() folds them to integers
// before building the Shuffle.
template<typename Vec, typename Base, typename Stride, typename Lanes>
struct SliceOp {
    struct pattern_tag {};
    Vec vec;
    Base base;
    Stride stride;
    Lanes lanes;

    // Union of the binding masks of all four sub-patterns.
    static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;

    // This pattern can only ever match a Shuffle node.
    constexpr static IRNodeType min_node_type = IRNodeType::Shuffle;
    constexpr static IRNodeType max_node_type = IRNodeType::Shuffle;
    constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;

    // Try to match the node e against this pattern.
    //
    // Returns true only if e is a Shuffle of exactly one vector that is a
    // genuine slice, and the vector, slice_begin(), slice_stride() and the
    // result lane count match the respective sub-patterns. Sub-patterns are
    // matched in the order vec, base, stride, lanes; each later one is
    // matched with `bound` widened by the binding masks of the earlier ones.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != IRNodeType::Shuffle) {
            return false;
        }
        const Shuffle &v = (const Shuffle &)e;
        // slice_begin() and slice_stride() only describe the shuffle when it
        // really is a slice. Without the is_slice() check an arbitrary
        // single-vector shuffle (e.g. a permutation) would match here and be
        // rebuilt by make() as a different shuffle.
        return v.vectors.size() == 1 &&
               v.is_slice() &&
               vec.template match<bound>(*v.vectors[0].get(), state) &&
               base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
               stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
               lanes.template match<bound | bindings<Vec>::mask | bindings<Base>::mask | bindings<Stride>::mask>(v.type.lanes(), state);
    }

    // Build the expression this pattern describes from the matcher state:
    // constant-fold base, stride and lanes to ints, build the vector
    // sub-expression, and wrap it with Shuffle::make_slice.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t base_val, stride_val, lanes_val;
        // Receives the folded type of each constant; the value is not used here.
        halide_type_t ty;
        base.make_folded_const(base_val, ty, state);
        int b = (int)base_val.u.i64;
        stride.make_folded_const(stride_val, ty, state);
        int s = (int)stride_val.u.i64;
        lanes.make_folded_const(lanes_val, ty, state);
        int l = (int)lanes_val.u.i64;
        return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
    }

    // A slice produces a vector expression, so it never folds to a constant.
    constexpr static bool foldable = false;

    HALIDE_ALWAYS_INLINE
    SliceOp(Vec v, Base b, Stride s, Lanes l)
        : vec(v), base(b), stride(s), lanes(l) {
        static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
        static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
        static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
    }
};

// Print a SliceOp pattern as "slice(vec, base, stride, lanes)".
template<typename Vec, typename Base, typename Stride, typename Lanes>
std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
    s << "slice(" << op.vec;
    s << ", " << op.base;
    s << ", " << op.stride;
    s << ", " << op.lanes;
    s << ")";
    return s;
}

// Construct a SliceOp pattern from four arguments, passing each one through
// pattern_arg before it is stored in the resulting pattern.
template<typename Vec, typename Base, typename Stride, typename Lanes>
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
    -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
    using VecPattern = decltype(pattern_arg(vec));
    using BasePattern = decltype(pattern_arg(base));
    using StridePattern = decltype(pattern_arg(stride));
    using LanesPattern = decltype(pattern_arg(lanes));
    return SliceOp<VecPattern, BasePattern, StridePattern, LanesPattern>{
        pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
}

template<typename A>
struct Fold {
struct pattern_tag {};
Expand Down Expand Up @@ -2551,7 +2614,7 @@ std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
}

template<typename A>
struct HasEvenLanes {
struct LanesOf {
struct pattern_tag {};
A a;

Expand All @@ -2568,22 +2631,22 @@ struct HasEvenLanes {
void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
// a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
Type t = a.make(state, {}).type();
val.u.u64 = (t.lanes() % 2 == 0);
val.u.u64 = t.lanes();
ty.code = halide_type_uint;
ty.bits = 1;
ty.lanes = t.lanes();
ty.bits = 32;
ty.lanes = 1;
}
};

template<typename A>
HALIDE_ALWAYS_INLINE auto has_even_lanes(A &&a) noexcept -> HasEvenLanes<decltype(pattern_arg(a))> {
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
assert_is_lvalue_if_expr<A>();
return {pattern_arg(a)};
}

template<typename A>
std::ostream &operator<<(std::ostream &s, const HasEvenLanes<A> &op) {
s << "has_even_lanes(" << op.a << ")";
std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
s << "lanes_of(" << op.a << ")";
return s;
}

Expand Down
22 changes: 10 additions & 12 deletions src/Simplify_Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
rewrite(y*x + y*z, y*(x + z)) ||
rewrite(x*c0 + y*c1, (x + y*fold(c1/c0)) * c0, c1 % c0 == 0) ||
rewrite(x*c0 + y*c1, (x*fold(c0/c1) + y) * c1, c0 % c1 == 0) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(slice(x, c0, c1, c2) + slice(y, c0, c1, c2), slice(x + y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) + (z + slice(y, c0, c1, c2)), slice(x + y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) + (slice(y, c0, c1, c2) + z), slice(x + y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) + (z - slice(y, c0, c1, c2)), slice(x - y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) + (slice(y, c0, c1, c2) - z), slice(x + y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||

(no_overflow(op->type) &&
(rewrite(x + x*y, x * (y + 1)) ||
rewrite(x + y*x, (y + 1) * x) ||
Expand Down Expand Up @@ -187,18 +197,6 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
return mutate(rewrite.result, bounds);
}
// clang-format on

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Add>(op);
} else {
return hoist_slice_vector<Add>(Add::make(a, b));
}
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
Expand Down
3 changes: 0 additions & 3 deletions src/Simplify_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,6 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
return f;
}

template<typename T>
Expr hoist_slice_vector(Expr e);

Stmt mutate_let_body(const Stmt &s, ExprInfo *) {
return mutate(s);
}
Expand Down
19 changes: 7 additions & 12 deletions src/Simplify_Max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) {

rewrite(max(select(x, y, z), select(x, w, u)), select(x, max(y, w), max(z, u))) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(max(slice(x, c0, c1, c2), slice(y, c0, c1, c2)), slice(max(x, y), c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(max(slice(x, c0, c1, c2), max(slice(y, c0, c1, c2), z)), max(slice(max(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(max(slice(x, c0, c1, c2), max(z, slice(y, c0, c1, c2))), max(slice(max(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) ||

(no_overflow(op->type) &&
(rewrite(max(max(x, y) + c0, x), max(x, y + c0), c0 < 0) ||
rewrite(max(max(x, y) + c0, x), max(x, y) + c0, c0 > 0) ||
Expand Down Expand Up @@ -299,18 +306,6 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) {
// clang-format on
}

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Max>(op);
} else {
return hoist_slice_vector<Max>(Max::make(a, b));
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
return op;
} else {
Expand Down
18 changes: 6 additions & 12 deletions src/Simplify_Min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) {

rewrite(min(select(x, y, z), select(x, w, u)), select(x, min(y, w), min(z, u))) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(min(slice(x, c0, c1, c2), slice(y, c0, c1, c2)), slice(min(x, y), c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(min(slice(x, c0, c1, c2), min(slice(y, c0, c1, c2), z)), min(slice(min(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(min(slice(x, c0, c1, c2), min(z, slice(y, c0, c1, c2))), min(slice(min(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
(no_overflow(op->type) &&
(rewrite(min(min(x, y) + c0, x), min(x, y + c0), c0 > 0) ||
rewrite(min(min(x, y) + c0, x), min(x, y) + c0, c0 < 0) ||
Expand Down Expand Up @@ -311,18 +317,6 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) {
// clang-format on
}

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Min>(op);
} else {
return hoist_slice_vector<Min>(Min::make(a, b));
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
return op;
} else {
Expand Down
20 changes: 8 additions & 12 deletions src/Simplify_Mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,19 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) {
rewrite(ramp(x, y, c0) * broadcast(z, c0), ramp(x * z, y * z, c0)) ||
rewrite(ramp(broadcast(x, c0), broadcast(y, c0), c1) * broadcast(z, c2),
ramp(broadcast(x * z, c0), broadcast(y * z, c0), c1), c2 == c0 * c1) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(slice(x, c0, c1, c2) * slice(y, c0, c1, c2), slice(x * y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) * (slice(y, c0, c1, c2) * z), slice(x * y, c0, c1, c2) * z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) * (z * slice(y, c0, c1, c2)), slice(x * y, c0, c1, c2) * z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||

false) {
return mutate(rewrite.result, bounds);
}
}

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Mul>(op);
} else {
return hoist_slice_vector<Mul>(Mul::make(a, b));
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
return op;
} else {
Expand Down
42 changes: 0 additions & 42 deletions src/Simplify_Shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,47 +321,5 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) {
}
}

// Given a binary node of type T whose operands are both slice shuffles,
// try to apply T to the underlying vectors and take the slice of the result.
// Returns e unchanged when the two slices use different indices, wrap a
// different number of vectors, or wrap vectors whose types differ pairwise.
template<typename T>
Expr Simplify::hoist_slice_vector(Expr e) {
    const T *op = e.as<T>();
    internal_assert(op);

    const Shuffle *shuffle_a = op->a.template as<Shuffle>();
    const Shuffle *shuffle_b = op->b.template as<Shuffle>();

    internal_assert(shuffle_a && shuffle_b &&
                    shuffle_a->is_slice() &&
                    shuffle_b->is_slice());

    const std::vector<Expr> &lhs_vecs = shuffle_a->vectors;
    const std::vector<Expr> &rhs_vecs = shuffle_b->vectors;

    // Both slices must select the same lanes from the same number of vectors.
    if (shuffle_a->indices != shuffle_b->indices ||
        lhs_vecs.size() != rhs_vecs.size()) {
        return e;
    }

    // The vectors are combined pairwise, so each pair must agree on type.
    const size_t count = lhs_vecs.size();
    for (size_t idx = 0; idx != count; ++idx) {
        if (lhs_vecs[idx].type() != rhs_vecs[idx].type()) {
            return e;
        }
    }

    // Apply the operation to the full vectors, then slice the result.
    std::vector<Expr> hoisted;
    for (size_t idx = 0; idx != count; ++idx) {
        hoisted.push_back(T::make(lhs_vecs[idx], rhs_vecs[idx]));
    }

    return Shuffle::make(hoisted, shuffle_a->indices);
}

// Explicit instantiations of hoist_slice_vector for each binary node type
// it is used with (Add, Sub, Mul, Min, Max).
template Expr Simplify::hoist_slice_vector<Add>(Expr);
template Expr Simplify::hoist_slice_vector<Sub>(Expr);
template Expr Simplify::hoist_slice_vector<Mul>(Expr);
template Expr Simplify::hoist_slice_vector<Min>(Expr);
template Expr Simplify::hoist_slice_vector<Max>(Expr);

} // namespace Internal
} // namespace Halide
21 changes: 9 additions & 12 deletions src/Simplify_Sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) {
rewrite(x - x%c0, (x/c0)*c0) ||
rewrite(x - ((x + c0)/c1)*c1, (x + c0)%c1 - c0, c1 > 0) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(slice(x, c0, c1, c2) - slice(y, c0, c1, c2), slice(x - y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) - (z + slice(y, c0, c1, c2)), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) - (slice(y, c0, c1, c2) + z), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite((slice(x, c0, c1, c2) - z) - slice(y, c0, c1, c2), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite((z - slice(x, c0, c1, c2)) - slice(y, c0, c1, c2), z - slice(x + y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||

(no_overflow(op->type) &&
(rewrite(max(x, y) - x, max(y - x, 0)) ||
rewrite(min(x, y) - x, min(y - x, 0)) ||
Expand Down Expand Up @@ -442,18 +451,6 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) {
}
// clang-format on

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Sub>(op);
} else {
return hoist_slice_vector<Sub>(Sub::make(a, b));
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
return op;
} else {
Expand Down
17 changes: 17 additions & 0 deletions test/correctness/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,23 @@ void check_vectors() {
check(concat_vectors(loads), Load::make(Float(32, lanes * vectors), "buf", ramp(0, 1, lanes * vectors), Buffer<>(), Parameter(), const_true(vectors * lanes), ModulusRemainder(0, 0)));
}

{
Expr vx = Variable::make(Int(32, 32), "x");
Expr vy = Variable::make(Int(32, 32), "y");
Expr vz = Variable::make(Int(32, 8), "z");
Expr vw = Variable::make(Int(32, 16), "w");
// Check that vector slices are hoisted.
check(slice(vx, 0, 2, 8) + slice(vy, 0, 2, 8), slice(vx + vy, 0, 2, 8));
check(slice(vx, 0, 2, 8) + (slice(vy, 0, 2, 8) + vz), slice(vx + vy, 0, 2, 8) + vz);
check(slice(vx, 0, 2, 8) + (vz + slice(vy, 0, 2, 8)), slice(vx + vy, 0, 2, 8) + vz);
// Check that degenerate vector slices are not hoisted.
check(slice(vx, 0, 2, 1) + slice(vy, 0, 2, 1), slice(vx, 0, 2, 1) + slice(vy, 0, 2, 1));
check(slice(vx, 0, 2, 1) + (slice(vy, 0, 2, 1) + z), slice(vx, 0, 2, 1) + (slice(vy, 0, 2, 1) + z));
// Check slices are only hoisted when the lanes of the sliced vectors match.
check(slice(vx, 0, 2, 8) + slice(vw, 0, 2, 8), slice(vx, 0, 2, 8) + slice(vw, 0, 2, 8));
check(slice(vx, 0, 2, 8) + (slice(vw, 0, 2, 8) + vz), slice(vx, 0, 2, 8) + (slice(vw, 0, 2, 8) + vz));
}

{
// A predicated store with a provably-false predicate.
Expr pred = ramp(x * y + x * z, 2, 8) > 2;
Expand Down