Skip to content

Commit 4af2eb5

Browse files
authored
[SYCL] Add support for NVPTX device printf (#4293)
Use `::printf` when not compiling for `__SPIR__`, this allows the use of `EmitNVPTXDevicePrintfCallExpr` which packs the var args and dispatches to CUDA's `vprintf`. Fixes #1154
1 parent 65dfb9d commit 4af2eb5

File tree

5 files changed

+46
-4
lines changed

5 files changed

+46
-4
lines changed

clang/lib/CodeGen/CGGPUBuiltin.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,19 @@ CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E,
118118

119119
// Invoke vprintf and return.
120120
llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule());
121-
return RValue::get(Builder.CreateCall(
122-
VprintfFunc, {Args[0].getRValue(*this).getScalarVal(), BufferPtr}));
121+
auto FormatSpecifier = Args[0].getRValue(*this).getScalarVal();
122+
// Check if the format specifier is in the constant address space, vprintf is
123+
// oblivious to address spaces, so it would have to be casted away.
124+
if (Args[0]
125+
.getRValue(*this)
126+
.getScalarVal()
127+
->getType()
128+
->getPointerAddressSpace() == 4)
129+
FormatSpecifier = Builder.CreateAddrSpaceCast(
130+
FormatSpecifier, llvm::Type::getInt8PtrTy(Ctx));
131+
132+
return RValue::get(
133+
Builder.CreateCall(VprintfFunc, {FormatSpecifier, BufferPtr}));
123134
}
124135

125136
RValue

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,8 @@ static bool IsSyclMathFunc(unsigned BuiltinID) {
420420
bool Sema::isKnownGoodSYCLDecl(const Decl *D) {
421421
if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
422422
const IdentifierInfo *II = FD->getIdentifier();
423+
if (FD->getBuiltinID() == Builtin::BIprintf)
424+
return true;
423425
const DeclContext *DC = FD->getDeclContext();
424426
if (II && II->isStr("__spirv_ocl_printf") &&
425427
!FD->isDefined() &&

clang/test/CodeGenSYCL/Inputs/sycl.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))
44

5+
extern "C" int printf(const char* fmt, ...);
6+
57
// Dummy runtime classes to model SYCL API.
68
inline namespace cl {
79
namespace sycl {
@@ -310,6 +312,21 @@ class spec_constant {
310312
return get();
311313
}
312314
};
315+
316+
#ifdef __SYCL_DEVICE_ONLY__
317+
#define __SYCL_CONSTANT_AS __attribute__((opencl_constant))
318+
#else
319+
#define __SYCL_CONSTANT_AS
320+
#endif
321+
template <typename... Args>
322+
int printf(const __SYCL_CONSTANT_AS char *__format, Args... args) {
323+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
324+
return __spirv_ocl_printf(__format, args...);
325+
#else
326+
return ::printf(__format, args...);
327+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
328+
}
329+
313330
} // namespace experimental
314331
} // namespace oneapi
315332
} // namespace ext
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -triple nvptx64-nvidia-cuda-sycldevice -std=c++11 -S -emit-llvm -x c++ %s -o - | FileCheck %s
2+
3+
#include "Inputs/sycl.hpp"
4+
5+
static const __SYCL_CONSTANT_AS char format_2[] = "Hello! %d %f\n";
6+
7+
int main() {
8+
// Make sure that device printf is dispatched to CUDA's vprintf syscall.
9+
// CHECK: alloca %printf_args
10+
// CHECK: call i32 @vprintf
11+
cl::sycl::kernel_single_task<class first_kernel>([]() { cl::sycl::ext::oneapi::experimental::printf(format_2, 123, 1.23); });
12+
}

sycl/include/sycl/ext/oneapi/experimental/builtins.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ namespace experimental {
6161
//
6262
template <typename... Args>
6363
int printf(const __SYCL_CONSTANT_AS char *__format, Args... args) {
64-
#ifdef __SYCL_DEVICE_ONLY__
64+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
6565
return __spirv_ocl_printf(__format, args...);
6666
#else
6767
return ::printf(__format, args...);
68-
#endif
68+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
6969
}
7070

7171
} // namespace experimental

0 commit comments

Comments
 (0)