Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bool conversion bug in Vulkan code generator #8067

Merged
merged 2 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/CodeGen_Vulkan_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,9 @@ void fill_bytes_with_value(uint8_t *bytes, int count, int value) {
}

SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(Type target_type, Type value_type, SpvId value_id) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(): casting from value type '"
<< value_type << "' to target type '" << target_type << "' for value id '" << value_id << "' !\n";

if (!value_type.is_bool()) {
value_id = cast_type(Bool(), value_type, value_id);
}
Expand Down Expand Up @@ -590,8 +593,8 @@ SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(Type target_type, Type

SpvId result_id = builder.reserve_id(SpvResultId);
SpvId target_type_id = builder.declare_type(target_type);
SpvId true_value_id = builder.declare_constant(target_type, &true_data);
SpvId false_value_id = builder.declare_constant(target_type, &false_data);
SpvId true_value_id = builder.declare_constant(target_type, &true_data[0]);
SpvId false_value_id = builder.declare_constant(target_type, &false_data[0]);
builder.append(SpvFactory::select(target_type_id, result_id, value_id, true_value_id, false_value_id));
return result_id;
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ tests(GROUPS correctness
bit_counting.cpp
bitwise_ops.cpp
bool_compute_root_vectorize.cpp
bool_predicate_cast.cpp
bound.cpp
bound_small_allocations.cpp
bound_storage.cpp
Expand Down
39 changes: 39 additions & 0 deletions test/correctness/bool_predicate_cast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {

// Test explicit casting of a predicate to an integer as part of a reduction
// NOTE: triggers a convert_to_bool in Vulkan for a SelectOp
Target target = get_jit_target_from_environment();
Var x("x"), y("y");

Func input("input");
input(x, y) = cast<uint8_t>(x + y);

Func test("test");
test(x, y) = cast(UInt(8), input(x, y) >= 32);

if (target.has_gpu_feature()) {
Var xi("xi"), yi("yi");
test.gpu_tile(x, y, xi, yi, 8, 8);
}

Realization result = test.realize({96, 96});
Buffer<uint8_t> a = result[0];
for (int y = 0; y < a.height(); y++) {
for (int x = 0; x < a.width(); x++) {
uint8_t correct_a = ((x + y) >= 32) ? 1 : 0;
if (a(x, y) != correct_a) {
printf("result(%d, %d) = (%d) instead of (%d)\n",
x, y, a(x, y), correct_a);
return 1;
}
}
}

printf("Success!\n");
return 0;
}
Loading