Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type error in VectorizeLoops #8055

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Interval bounds_of_lanes(const Expr &e) {
Interval ia = bounds_of_lanes(not_->a);
return {!ia.max, !ia.min};
} else if (const Ramp *r = e.as<Ramp>()) {
Expr last_lane_idx = make_const(r->base.type(), r->lanes - 1);
Expr last_lane_idx = make_const(r->base.type().element_of(), r->lanes - 1);
Interval ib = bounds_of_lanes(r->base);
const Broadcast *b = as_scalar_broadcast(r->stride);
Expr stride = b ? b->value : r->stride;
Expand Down Expand Up @@ -875,6 +875,7 @@ class VectorSubs : public IRMutator {
// generating a scalar condition that checks if
// the least-true lane is true.
Expr all_true = bounds_of_lanes(likely->args[0]).min;
internal_assert(all_true.type() == Bool());
// Wrap it in the same flavor of likely
all_true = Call::make(Bool(), likely->name,
{all_true}, Call::PureIntrinsic);
Expand Down
68 changes: 68 additions & 0 deletions test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,74 @@ int main(int argc, char **argv) {
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/8054
{
ImageParam input(Float(32), 2, "input");
const float r_sigma = 0.1;
const int s_sigma = 8;
Func bilateral_grid{"bilateral_grid"};

Var x("x"), y("y"), z("z"), c("c");

// Add a boundary condition
Func clamped = Halide::BoundaryConditions::repeat_edge(input);

// Construct the bilateral grid
RDom r(0, s_sigma, 0, s_sigma);
Expr val = clamped(x * s_sigma + r.x - s_sigma / 2, y * s_sigma + r.y - s_sigma / 2);
val = clamp(val, 0.0f, 1.0f);

Expr zi = cast<int>(val * (1.0f / r_sigma) + 0.5f);

Func histogram("histogram");
histogram(x, y, z, c) = 0.0f;
histogram(x, y, zi, c) += mux(c, {val, 1.0f});

// Blur the grid using a five-tap filter
Func blurx("blurx"), blury("blury"), blurz("blurz");
blurz(x, y, z, c) = (histogram(x, y, z - 2, c) +
histogram(x, y, z - 1, c) * 4 +
histogram(x, y, z, c) * 6 +
histogram(x, y, z + 1, c) * 4 +
histogram(x, y, z + 2, c));
blurx(x, y, z, c) = (blurz(x - 2, y, z, c) +
blurz(x - 1, y, z, c) * 4 +
blurz(x, y, z, c) * 6 +
blurz(x + 1, y, z, c) * 4 +
blurz(x + 2, y, z, c));
blury(x, y, z, c) = (blurx(x, y - 2, z, c) +
blurx(x, y - 1, z, c) * 4 +
blurx(x, y, z, c) * 6 +
blurx(x, y + 1, z, c) * 4 +
blurx(x, y + 2, z, c));

// Take trilinear samples to compute the output
val = clamp(input(x, y), 0.0f, 1.0f);
Expr zv = val * (1.0f / r_sigma);
zi = cast<int>(zv);
Expr zf = zv - zi;
Expr xf = cast<float>(x % s_sigma) / s_sigma;
Expr yf = cast<float>(y % s_sigma) / s_sigma;
Expr xi = x / s_sigma;
Expr yi = y / s_sigma;
Func interpolated("interpolated");
interpolated(x, y, c) =
lerp(lerp(lerp(blury(xi, yi, zi, c), blury(xi + 1, yi, zi, c), xf),
lerp(blury(xi, yi + 1, zi, c), blury(xi + 1, yi + 1, zi, c), xf), yf),
lerp(lerp(blury(xi, yi, zi + 1, c), blury(xi + 1, yi, zi + 1, c), xf),
lerp(blury(xi, yi + 1, zi + 1, c), blury(xi + 1, yi + 1, zi + 1, c), xf), yf),
zf);

// Normalize
bilateral_grid(x, y) = interpolated(x, y, 0) / interpolated(x, y, 1);
Pipeline p({bilateral_grid});

Var v6, zo, vzi;

blury.compute_root().split(x, x, v6, 6, TailStrategy::GuardWithIf).split(z, zo, vzi, 8, TailStrategy::GuardWithIf).reorder(y, x, c, vzi, zo, v6).vectorize(vzi).vectorize(v6);
p.compile_to_module({input}, "bilateral_grid", {Target("host")});
}

printf("Success!\n");
return 0;
}
Loading