Skip to content

Specialize does not remove select conditions as expected #8443

Closed
@yyuting

Description

@yyuting

When I specialize on the conditions of a nested select, Halide does not remove the selects inside the specialized branches as I would expect.
In the generated stmt file, the specialization condition in the outer if/else branch has been rewritten into a form that no longer matches the select condition.

I'm using Halide 18.0.0 on Mac M2.

Here's a minimal generator example:

// specialize_generator.cpp
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

// Minimal reproduction generator: writes a 2-D float output from a 2-D float
// input, choosing between a half-resolution read, a double-resolution read,
// or a constant 0.0f based on two scalar scale factors, and then asks for
// specializations on those same choice conditions.
class SpecializeBugGenerator
    : public Halide::Generator<SpecializeBugGenerator> {
public:
  // Source image, declared with 2 dimensions.
  Input<Buffer<float>> input{"input", 2};
  // Scalar scale factors; each one is only ever compared against 1.0f.
  Input<float> scale_factor_x{"scale_factor_x"};
  Input<float> scale_factor_y{"scale_factor_y"};

  // Destination image, declared with 2 dimensions.
  Output<Buffer<float>> output{"output", 2};

  // Pure variables used to index the output definition.
  Var x, y;

  // Defines the output pipeline and its specializations.
  void generate() {

    // Per-axis predicates: true when that axis' scale factor exceeds 1.
    Expr upsample_x = scale_factor_x > 1.0f;
    Expr upsample_y = scale_factor_y > 1.0f;
    // upsample: both axes exceed 1. downsample: neither axis exceeds 1.
    // The two are mutually exclusive by construction, and both are false
    // when exactly one axis exceeds 1.
    Expr upsample = upsample_x && upsample_y;
    Expr downsample = !upsample_x && !upsample_y;

    // Nested select: read input at (x/2, y/2) when upsampling, at (2x, 2y)
    // when downsampling, and produce 0.0f in the remaining (mixed) case.
    output(x, y) = select(upsample, input(cast<int>(x / 2), cast<int>(y / 2)),
                          select(downsample, input(x * 2, y * 2), 0.0f));

    // Chained specializations, one per combination of upsample/!upsample
    // with downsample/!downsample; the second specialize() is called on the
    // result of the first. Because upsample and downsample cannot both be
    // true, the first chain's combined condition is unsatisfiable.
    output.specialize(upsample).specialize(downsample);
    output.specialize(upsample).specialize(!downsample);
    output.specialize(!upsample).specialize(downsample);
    output.specialize(!upsample).specialize(!downsample);
  }
};

// Registers SpecializeBugGenerator under the name "specialize_bug_generator".
HALIDE_REGISTER_GENERATOR(SpecializeBugGenerator, specialize_bug_generator)

Here is the relevant part of the generated stmt file:

...
if (1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) {
   if (max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) {
    let t61 = 1.000000f < (float32)scale_factor_y
    let t60 = 1.000000f < (float32)scale_factor_x
    let t63 = 0 - (output.min.1*output.stride.1)
    let t62 = (input.min.1*input.stride.1) + input.min.0
    for (output.s0.v1.rebased, 0, output.extent.1) {
     let t66 = t60 || t61
     let t65 = t60 && t61
     let t64 = output.min.1 + output.s0.v1.rebased
     for (output.s0.v0.rebased, 0, output.extent.0) {
      output[((output.stride.1*t64) + t63) + output.s0.v0.rebased] = select(t65, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t64/2)*input.stride.1) - t62)], select(t66, 0.000000f, input[((((input.stride.1*t64) + output.min.0) + output.s0.v0.rebased)*2) - t62]))
     }
    }
   }
...

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions