Skip to content

[SYCL] Add support for NVPTX device printf #4293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions clang/lib/CodeGen/CGGPUBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,19 @@ CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E,

// Invoke vprintf and return.
llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule());
return RValue::get(Builder.CreateCall(
VprintfFunc, {Args[0].getRValue(*this).getScalarVal(), BufferPtr}));
auto FormatSpecifier = Args[0].getRValue(*this).getScalarVal();
// Check if the format specifier is in the constant address space, vprintf is
// oblivious to address spaces, so it would have to be casted away.
if (Args[0]
.getRValue(*this)
.getScalarVal()
->getType()
->getPointerAddressSpace() == 4)
FormatSpecifier = Builder.CreateAddrSpaceCast(
FormatSpecifier, llvm::Type::getInt8PtrTy(Ctx));

return RValue::get(
Builder.CreateCall(VprintfFunc, {FormatSpecifier, BufferPtr}));
}

RValue
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ static bool IsSyclMathFunc(unsigned BuiltinID) {
bool Sema::isKnownGoodSYCLDecl(const Decl *D) {
if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
const IdentifierInfo *II = FD->getIdentifier();
if (FD->getBuiltinID() == Builtin::BIprintf)
return true;
Comment on lines +423 to +424
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tagging @AaronBallman, @elizabethandrews and @premanandrao for awareness: this if should probably be extended to check if we are compiling for CUDA or not.

The thing is that printf built-in as-is won't be properly converted into SPIR-V and will cause JIT compilation errors on OpenCL backends.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing it to check for CUDA makes sense to me, good catch!

const DeclContext *DC = FD->getDeclContext();
if (II && II->isStr("__spirv_ocl_printf") &&
!FD->isDefined() &&
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CodeGenSYCL/Inputs/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))

extern "C" int printf(const char* fmt, ...);

// Dummy runtime classes to model SYCL API.
inline namespace cl {
namespace sycl {
Expand Down Expand Up @@ -310,6 +312,21 @@ class spec_constant {
return get();
}
};

#ifdef __SYCL_DEVICE_ONLY__
#define __SYCL_CONSTANT_AS __attribute__((opencl_constant))
#else
#define __SYCL_CONSTANT_AS
#endif
template <typename... Args>
int printf(const __SYCL_CONSTANT_AS char *__format, Args... args) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
return __spirv_ocl_printf(__format, args...);
#else
return ::printf(__format, args...);
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
}

} // namespace experimental
} // namespace oneapi
} // namespace ext
Expand Down
12 changes: 12 additions & 0 deletions clang/test/CodeGenSYCL/nvptx-printf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: %clang_cc1 -fsycl-is-device -triple nvptx64-nvidia-cuda-sycldevice -std=c++11 -S -emit-llvm -x c++ %s -o - | FileCheck %s

#include "Inputs/sycl.hpp"

static const __SYCL_CONSTANT_AS char format_2[] = "Hello! %d %f\n";

int main() {
// Make sure that device printf is dispatched to CUDA's vprintf syscall.
// CHECK: alloca %printf_args
// CHECK: call i32 @vprintf
cl::sycl::kernel_single_task<class first_kernel>([]() { cl::sycl::ext::oneapi::experimental::printf(format_2, 123, 1.23); });
}
4 changes: 2 additions & 2 deletions sycl/include/sycl/ext/oneapi/experimental/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ namespace experimental {
//
template <typename... Args>
int printf(const __SYCL_CONSTANT_AS char *__format, Args... args) {
#ifdef __SYCL_DEVICE_ONLY__
#if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
return __spirv_ocl_printf(__format, args...);
#else
return ::printf(__format, args...);
#endif
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
}

} // namespace experimental
Expand Down