Skip to content

Commit bb00a15

Browse files
[CUDA][CodeGen] Fix cuda codegen's fp16 inf literal (#12581)
* Fix cuda codegen's fp16 inf literal
* Add relay testcase
1 parent 21db1eb commit bb00a15

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

src/target/source/codegen_cuda.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,8 +1197,10 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
         break;
       }
       case 16: {
-        os << "__float2half_rn";
-        os << '(' << std::scientific << op->value << 'f' << ')';
+        os << "__float2half_rn" << '(';
+        FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
+        PrintConst(const_f32.get(), os, p);
+        os << ')';
         break;
       }
       default:

tests/python/relay/test_op_level3.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,7 +1344,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0, indices_dtype="int32"
     verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1, indices_dtype="uint32")


-def _verify_infiniteness_ops(relay_op, ref_op):
+def _verify_infiniteness_ops(relay_op, ref_op, target="llvm", dev=None):
     for dtype in ["float32", "float16", "float16", "int32", "int16"]:
         shape = (2, 8, 8)
         x = relay.var("x", relay.TensorType(shape, dtype))
@@ -1359,17 +1359,25 @@ def _verify_infiniteness_ops(relay_op, ref_op):
         ] = np.infty
         data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan

-        op_res = create_executor().evaluate(y, {x: data})
+        op_res = create_executor(target=target, device=dev).evaluate(y, {x: data})
         ref_res = ref_op(data)
         np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01)


+@tvm.testing.requires_gpu
 def test_isfinite():
-    _verify_infiniteness_ops(relay.isfinite, np.isfinite)
+    for target, dev in tvm.testing.enabled_targets():
+        if target not in ["llvm", "cuda"]:
+            continue
+        _verify_infiniteness_ops(relay.isfinite, np.isfinite, target=target, dev=dev)


+@tvm.testing.requires_gpu
 def test_isinf():
-    _verify_infiniteness_ops(relay.isinf, np.isinf)
+    for target, dev in tvm.testing.enabled_targets():
+        if target not in ["llvm", "cuda"]:
+            continue
+        _verify_infiniteness_ops(relay.isinf, np.isinf, target=target, dev=dev)


 def test_unravel_index(target, dev, executor_kind):

0 commit comments

Comments (0)