Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it possible to interpret a wide type as multiple smaller elements #6506

Merged
merged 2 commits into from
Jan 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions src/Simplify_Shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,72 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) {
new_vectors.push_back(new_vector);
}

// If any of the args are narrowing casts, convert them to shuffles of
// reinterpret casts so we can fold them into this shuffle. This all assumes
// little-endianness, so if we ever support a big-endian backend we'll have
// to switch on the target here.
for (Expr &v : new_vectors) {
if (!(v.type().is_int() || v.type().is_uint()) || !v.as<Cast>()) {
continue;
}

auto x_is_16_bit = is_int(x, 16) || is_uint(x, 16);
auto x_is_32_bit = is_int(x, 32) || is_uint(x, 32);
auto x_is_64_bit = is_int(x, 64) || is_uint(x, 64);

auto rewrite = IRMatcher::rewriter(v, op->type);

auto t8 = v.type().with_bits(8);
auto t16 = v.type().with_bits(16);
auto t32 = v.type().with_bits(32);

// Right-shifts by less than 32 bits have been canonicalized to
// divisions, so the patterns below match a division for shift amounts
// under 32 bits and an explicit shift_right for 32 bits and above.
int stride = 0, start = 0;

if (rewrite(cast(t8, x / (1 << 8)), x, x_is_16_bit) ||
rewrite(cast(t16, x / (1 << 16)), x, x_is_32_bit) ||
rewrite(cast(t32, shift_right(x, 32)), x, x_is_64_bit)) {
// Extract high half
stride = 2;
start = 1;
} else if (rewrite(cast(t8, x), x, x_is_16_bit) ||
rewrite(cast(t16, x), x, x_is_32_bit) ||
rewrite(cast(t32, x), x, x_is_64_bit)) {
// Extract low half
stride = 2;
start = 0;
} else if (rewrite(cast(t8, x / (1 << 24)), x, x_is_32_bit) ||
rewrite(cast(t16, shift_right(x, 48)), x, x_is_64_bit)) {
// Extract 4th quarter
stride = 4;
start = 3;
} else if (rewrite(cast(t8, x / (1 << 16)), x, x_is_32_bit) ||
rewrite(cast(t16, shift_right(x, 32)), x, x_is_64_bit)) {
// Extract 3rd quarter
stride = 4;
start = 2;
} else if (rewrite(cast(t8, x / (1 << 8)), x, x_is_32_bit) ||
rewrite(cast(t16, x / (1 << 16)), x, x_is_64_bit)) {
// Extract 2nd quarter
stride = 4;
start = 1;
} else if (rewrite(cast(t8, x), x, x_is_32_bit) ||
rewrite(cast(t16, x), x, x_is_64_bit)) {
// Extract low quarter
stride = 4;
start = 0;
} else {
continue;
}

int lanes = v.type().lanes();
v = reinterpret(v.type().with_lanes(lanes * stride), rewrite.result);
v = Shuffle::make_slice(v, start, stride, lanes);
changed = true;
}

// Try to convert a load with shuffled indices into a
// shuffle of a dense load.
if (const Load *first_load = new_vectors[0].as<Load>()) {
Expand Down Expand Up @@ -191,6 +257,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) {
}
}
}

} else if (op->is_concat()) {
// Try to collapse a concat of ramps into a single ramp.
const Ramp *r = new_vectors[0].as<Ramp>();
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ tests(GROUPS correctness
reduction_non_rectangular.cpp
reduction_schedule.cpp
register_shuffle.cpp
reinterpret_vector.cpp
reorder_rvars.cpp
reorder_storage.cpp
require.cpp
Expand Down
90 changes: 90 additions & 0 deletions test/correctness/reinterpret_vector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#include "Halide.h"

using namespace Halide;
using namespace Halide::Internal;

// Custom lowering pass used as a checker: it runs over the IR and terminates
// the process if it meets a vector-typed expression that is not one of a
// small set of permitted node kinds. It never rewrites anything.
class CheckNoVectorMath : public IRMutator {
public:
    using IRMutator::mutate;

    Expr mutate(const Expr &e) override {
        // Run the base-class traversal on e first. Its result is discarded
        // on purpose: this pass only inspects expressions.
        IRMutator::mutate(e);

        // Scalar expressions are always acceptable.
        if (!e.type().is_vector()) {
            return e;
        }

        // Allow-list of vector IR nodes we are OK with.
        const bool permitted = Call::as_intrinsic(e, {Call::reinterpret}) ||
                               e.as<Load>() ||
                               e.as<Ramp>() ||
                               e.as<Variable>() ||
                               e.as<Broadcast>();

        if (!permitted) {
            std::cout << "Unexpected vector expression: " << e << "\n";
            exit(-1);
        }

        return e;
    }
};

// Test driver. Builds two pipelines that read narrow pieces out of a wider
// integer type, installs CheckNoVectorMath as a custom lowering pass on each
// (which exits the process on any vector expression outside its allow-list),
// realizes them, and compares every output element against a scalar
// reference. Returns 0 on success and -1 on the first mismatch.
int main(int argc, char **argv) {
    // Check we can treat a vector of a wide type as a wider vector of a
    // narrower type for free.
    Var x, y, c;

    // Treat a 32-bit image as a twice-as-wide 16-bit image
    {
        Func narrow, wide;
        wide(x, y) = cast<uint32_t>(x + y);
        // Even output columns take the low 16 bits of the wide value, odd
        // columns take the high 16 bits.
        narrow(x, y) = select(x % 2 == 0,
                              cast<uint16_t>(wide(x / 2, y)),
                              cast<uint16_t>(wide(x / 2, y) >> 16));
        wide.compute_root();
        narrow.align_bounds(x, 16).vectorize(x, 16);
        CheckNoVectorMath checker;
        narrow.add_custom_lowering_pass(&checker, nullptr);

        Buffer<uint16_t> out = narrow.realize({1024, 1024});

        for (int y = 0; y < out.height(); y++) {
            for (int x = 0; x < out.width(); x++) {
                int correct = ((x % 2 == 0) ? x / 2 + y : (x / 2 + y) >> 16);
                if (out(x, y) != correct) {
                    printf("out(%d, %d) = %d instead of %d\n", x, y, out(x, y), correct);
                    return -1;
                }
            }
        }
    }

    // Treat 2-dimensional 32-bit values representing rgba as 3-dimensional rgba
    {
        Func rgba_packed, rgba;
        rgba_packed(x, y) = cast<uint32_t>(x + y);
        rgba(c, x, y) = mux(c, {cast<uint8_t>(rgba_packed(x, y)),
                                cast<uint8_t>(rgba_packed(x, y) >> 8),
                                cast<uint8_t>(rgba_packed(x, y) >> 16),
                                cast<uint8_t>(rgba_packed(x, y) >> 24)});
        rgba_packed.compute_root();
        rgba.align_bounds(x, 16).vectorize(x, 16).bound(c, 0, 4).unroll(c);
        rgba.output_buffer().dim(1).set_stride(4);
        CheckNoVectorMath checker;
        rgba.add_custom_lowering_pass(&checker, nullptr);

        // The channel extent must be 4 to agree with bound(c, 0, 4) and the
        // stride-4 constraint on dimension 1 declared above; it also means
        // all four mux branches are exercised by the check below.
        Buffer<uint8_t> out = rgba.realize({4, 1024, 1024});

        for (int y = 0; y < out.dim(2).extent(); y++) {
            for (int x = 0; x < out.dim(1).extent(); x++) {
                for (int c = 0; c < out.dim(0).extent(); c++) {
                    // Channel c is byte c of the packed 32-bit value.
                    uint8_t correct = (c == 0) ? (x + y) :
                                      (c == 1) ? (x + y) >> 8 :
                                      (c == 2) ? (x + y) >> 16 :
                                                 (x + y) >> 24;
                    if (out(c, x, y) != correct) {
                        // Report the element that was actually compared.
                        printf("out(%d, %d, %d) = %d instead of %d\n", c, x, y, out(c, x, y), correct);
                        return -1;
                    }
                }
            }
        }
    }

    printf("Success!\n");
    return 0;
}