Skip to content

Commit 97a3f96

Browse files
Xingyu Zhouwweic
authored andcommitted
[codegen] Add multiple operands and function support when using fp16 compilation (apache#4056)
* overload half operators for cuda codegen * add float16 te test_op_level1 * fix test_op_level1.py * fix lint * disable fp16 test if gpu does not support * disable fp16 test if gpu does not support * bypass float16 test if gpu does not support float16
1 parent d15429c commit 97a3f96

File tree

3 files changed

+251
-203
lines changed

3 files changed

+251
-203
lines changed

src/codegen/codegen_cuda.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,20 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
5050
std::string CodeGenCUDA::Finish() {
5151
if (enable_fp16_) {
5252
decl_stream << "#include <cuda_fp16.h>\n";
53+
decl_stream << "__device__ half max" \
54+
"(const half a, const half b)\n"
55+
"{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
56+
decl_stream << "__device__ half min(const half a, const half b)\n"
57+
"{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
58+
decl_stream << "__device__ half operator+" \
59+
"(const volatile __half &a, const volatile __half &b)\n"
60+
"{\n return __hadd(a, b);\n}\n";
61+
decl_stream << "__device__ half operator<=" \
62+
"(const volatile __half &a, const volatile __half &b)\n"
63+
"{\n return __hlt(a, b);\n}\n";
64+
decl_stream << "__device__ half operator*" \
65+
"(const volatile __half &a, const volatile __half &b)\n"
66+
"{\n return __hmul(a, b);\n}\n";
5367
}
5468

5569
if (enable_int8_) {

0 commit comments

Comments
 (0)