Closed
Description
When specializing on the conditions of a nested select, Halide does not remove the select within the specialization as expected.
In the generated stmt file, the specialization condition in the outer if/else branch is rewritten into a form that differs from the select condition.
I'm using Halide 18.0.0 on Mac M2.
Here's a minimal generator example:
// specialize_generator.cpp
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
// Minimal reproducer: a generator whose output is a two-level nested select
// over two scale-factor conditions, with those same conditions also requested
// as nested specializations of the output.
class SpecializeBugGenerator
: public Halide::Generator<SpecializeBugGenerator> {
public:
// 2-D float input buffer that the selects read from.
Input<Buffer<float>> input{"input", 2};
// Scalar parameters; generate() compares each against 1.0f.
Input<float> scale_factor_x{"scale_factor_x"};
Input<float> scale_factor_y{"scale_factor_y"};
// 2-D float output buffer defined by generate().
Output<Buffer<float>> output{"output", 2};
// Pure variables indexing the two output dimensions.
Var x, y;
// Defines the output and requests the nested specializations.
void generate() {
// Per-axis predicates: true when the corresponding factor exceeds 1.
Expr upsample_x = scale_factor_x > 1.0f;
Expr upsample_y = scale_factor_y > 1.0f;
// upsample holds when both factors exceed 1; downsample when neither does.
// When exactly one factor exceeds 1, both are false.
Expr upsample = upsample_x && upsample_y;
Expr downsample = !upsample_x && !upsample_y;
// upsample   -> read input at (x / 2, y / 2)
// downsample -> read input at (x * 2, y * 2)
// otherwise  -> 0.0f
output(x, y) = select(upsample, input(cast<int>(x / 2), cast<int>(y / 2)),
select(downsample, input(x * 2, y * 2), 0.0f));
// Request every combination of (upsample, downsample) as a nested
// specialization; the issue expects the selects to be simplified away
// inside each branch.
// NOTE(review): by the definitions above, upsample and downsample cannot
// both be true, so the first combination describes an unsatisfiable case.
output.specialize(upsample).specialize(downsample);
output.specialize(upsample).specialize(!downsample);
output.specialize(!upsample).specialize(downsample);
output.specialize(!upsample).specialize(!downsample);
}
};
// Registers SpecializeBugGenerator under the name "specialize_bug_generator".
HALIDE_REGISTER_GENERATOR(SpecializeBugGenerator, specialize_bug_generator)
Here's part of the generated stmt.
...
if (1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) {
if (max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) {
let t61 = 1.000000f < (float32)scale_factor_y
let t60 = 1.000000f < (float32)scale_factor_x
let t63 = 0 - (output.min.1*output.stride.1)
let t62 = (input.min.1*input.stride.1) + input.min.0
for (output.s0.v1.rebased, 0, output.extent.1) {
let t66 = t60 || t61
let t65 = t60 && t61
let t64 = output.min.1 + output.s0.v1.rebased
for (output.s0.v0.rebased, 0, output.extent.0) {
output[((output.stride.1*t64) + t63) + output.s0.v0.rebased] = select(t65, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t64/2)*input.stride.1) - t62)], select(t66, 0.000000f, input[((((input.stride.1*t64) + output.min.0) + output.s0.v0.rebased)*2) - t62]))
}
}
}
...
Metadata
Metadata
Assignees
Labels
No labels