
Commit d4fb0a2

icemelon authored and tqchen committed
[BugFix] Fix bug in cast to bool (#3207)
1 parent 1b35903 commit d4fb0a2

File tree: 2 files changed, +50 −2 lines changed


src/codegen/llvm/codegen_llvm.cc

Lines changed: 8 additions & 0 deletions

@@ -537,6 +537,14 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
   if (value->getType() == target) return value;
   if (to.is_handle()) {
     return builder_->CreateBitCast(value, target);
+  } else if (to.is_uint() && to.bits() == 1) {
+    if (from.is_float()) {
+      llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
+      return builder_->CreateFCmpONE(value, zero);
+    } else {
+      llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
+      return builder_->CreateICmpNE(value, zero);
+    }
   } else if (!from.is_float() && !to.is_float()) {
     return builder_->CreateIntCast(value, target, from.is_int());
   } else if (from.is_float() && to.is_int()) {
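The new branch lowers a cast to a 1-bit unsigned type (bool) as a compare against zero: FCmpONE for floating-point inputs, ICmpNE for integer inputs, so the result is true exactly where the input is non-zero. A minimal sketch of driving that path from Python is below; it is not part of the commit and assumes an llvm-enabled build with the same pre-0.7 tvm/topi API that the new test uses.

# Sketch (not part of the commit): exercise the fixed float -> bool cast end to end.
# Assumes the llvm target is enabled in the local TVM build.
import numpy as np
import tvm
import topi

A = tvm.placeholder((4,), dtype="float32", name="A")
B = topi.cast(A, "bool")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.array([0.0, 1.5, -2.0, 0.0], dtype="float32"), ctx)
b = tvm.nd.empty(shape=(4,), dtype="bool", ctx=ctx)
f(a, b)
# With the fix, the cast matches numpy semantics: non-zero -> True, zero -> False.
np.testing.assert_array_equal(b.asnumpy(), np.array([False, True, True, False]))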

topi/tests/python/test_topi_math.py

Lines changed: 42 additions & 2 deletions

@@ -19,6 +19,7 @@
 import topi
 import topi.testing
 from topi import util
+from common import get_all_backend


 def test_util():
@@ -59,8 +60,7 @@ def check_device(device):
             foo(a, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

-        for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel',
-                       'aocl_sw_emu']:
+        for device in get_all_backend():
             check_device(device)


@@ -77,6 +77,46 @@ def check_device(device):
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
     test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)

+
+def test_cast():
+    def verify(from_dtype, to_dtype, low=-100, high=100):
+        shape = (5, 4)
+        A = tvm.placeholder(shape, dtype=from_dtype, name="A")
+        B = topi.cast(A, to_dtype)
+
+        if from_dtype == "bool":
+            a_np = np.random.choice([True, False], size=shape)
+        else:
+            a_np = np.random.uniform(low, high, size=shape).astype(from_dtype)
+        if to_dtype == "bool":
+            a_np = a_np - a_np[2, 3]
+        b_np = a_np.astype(to_dtype)
+
+        for device in get_all_backend():
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                continue
+            print("Running on target: %s" % device)
+            with tvm.target.create(device):
+                s = topi.generic.schedule_injective(B)
+            foo = tvm.build(s, [A, B], device)
+            a = tvm.nd.array(a_np, ctx)
+            b = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx)
+            foo(a, b)
+            tvm.testing.assert_allclose(b.asnumpy(), b_np)
+
+    verify("int32", "float32")
+    verify("int32", "float64")
+    verify("int32", "bool")
+    verify("float32", "int32")
+    verify("float32", "float64")
+    verify("float32", "bool")
+    verify("bool", "float32")
+    verify("bool", "int32")
+
+
 if __name__ == "__main__":
     test_util()
     test_ewise()
+    test_cast()
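One detail of the new test worth spelling out: the line a_np = a_np - a_np[2, 3] guarantees the input contains at least one exact zero, so the cast to bool is checked on both zero and non-zero values. The expected output b_np comes from numpy's own astype, whose bool cast maps non-zero to True and zero to False; a tiny standalone illustration (not part of the commit):

# Explanatory sketch only: numpy's astype("bool") is the reference the test compares
# against -- any non-zero value becomes True, an exact zero becomes False.
import numpy as np

a = np.array([[-1.5, 0.0, 2.0]], dtype="float32")
print(a.astype("bool"))   # [[ True False  True]]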
