Support clang as a CUDA compiler #4
Comments
Yeah, there is a plan. I have local modifications, but they are not ready yet. |
@Artem-B Considering you are one of the devs of clang cuda, I have a question about it. I now have a nearly working configuration for clang. The only problem I am facing is, for example: __global__ void kernel() {
// blahblah with impl
} with nvcc, the compiled object has the symbol
but with clang version 15.0.0 (https://github.com/llvm/llvm-project.git 009d56da5c4ea3666c4753ce7564c8c20d7e0255)
the |
First, kernels and stubs. Clang indeed no longer has NCCL does rely on RDC compilation (i.e. each source compiles to a GPU object file, instead of a fully linked GPU executable) and that part works very differently in clang vs nvcc. In a nutshell, object files need an extra final linking step, and a bit of extra host-side 'glue' code. NVCC does that under the hood. Clang does not, yet. Here's how tensorflow implements @jhuber6 has been working on clang driver changes to make compile-to-object-and-link "just work" on the GPU side. E.g. llvm/llvm-project@b7c8c4d |
Also, I believe CMake has recently added support for clang as the CUDA compiler. It may be worth checking whether/how they handle RDC compilation there. |
I think the expected way to perform RDC-mode compilation is via the CUDA_SEPARABLE_COMPILATION option. I think this is supported for Clang as well judging by this issue. |
Judging by the commit that has implemented it in cmake they did use the same RDC compilation process that we've implemented in tensorflow that I've pointed to above. You may as well just pick up tensorflow's implementation directly. @chsigg would probably be the most familiar with the details, if you have questions. |
OK, problem solved. It turns out that I need to compile all C code of nccl as CUDA with |
@Artem-B I think this is addressed in c13ebaa Use As you are a member of tensorflow, I am wondering if this can be mentioned or evaluated in tf. Might be good for bazel community ;) |
Not necessarily production ready, but at least usable. It needs more users to test it out before I can say it is production ready. Because it is a build system, there are too many corner cases in it. |
One thing that could serve as a motivation to adopt these changes would be to try getting Tensorflow to build using your rules, instead of the ones TF carries. It would be a pretty decent test of the real-world usability of the rules -- TF is probably the largest bazel user outside of Google and is almost certainly the largest user of clang for CUDA compilations. Having them convinced would go a long way towards convincing bazel owners that these rules should be part of bazel. Having a proof of concept at that level would also give TF owners a rough idea of how much work it would take to adopt it and whether it's worth it. One thing to keep in mind is that TF also has to work with our internal build. I don't know yet how hard it would be to switch to your rules. If it's a drop-in replacement of the |
@Artem-B Do we have a prebuilt LLVM package with the NVPTX backend enabled? I'd like to add a build CI so that I can finally close this issue with confidence. |
LLVM/Clang releases should have NVPTX built in. E.g. https://github.com/llvm/llvm-project/releases/tag/llvmorg-14.0.6 On a side note, just an FYI that there have been a lot of offloading-related changes in the clang driver lately that are going to make GPU compilation much closer to C++ compilation. E.g. RDC compilation would "just work" -- |
If you want to try out the new driver I would appreciate it. For compiling an application in RDC mode you can do the following.
Right now what's missing from the new driver is support for textures / surfaces, Windows / MacOS support, and compiling in non-RDC mode. The benefits are simplified compilation, static library support, and LTO among others. |
It'd be good to be able to load clang from https://github.com/grailbio/bazel-toolchain so that we can have a hermetic toolchain setup. I'll probably look into this at some point soon as we're already using that toolchain for our host builds and will be using rules_cuda soon within one of our product builds. |
The clang from the LLVM apt repository is also built with NVPTX enabled, so we can use that too. |
This is partially fixed by #143. Later I will add a full integration test by adding nccl as an example. The cloudhan/nccl-example branch should be buildable with both clang and nvcc. |
Are there any flags I should add besides maybe those:
in theory for this to work? I'm having a weird issue: everything compiles fine, but then on execution it just dies without any output. Maybe I'm living a bit too close to the edge using clang 17 and CUDA 12.1? It does say it's only partially supported... My whole setup is available here: https://github.com/hypdeb/lawrencium. |
At least running something like this should definitely be supported. The only time I've seen errors like this in the past is when there's no supported architecture; in that case it tends to just silently die. E.g. if I compile for
If you're using RDC-mode w/ clang you'll need to opt-in.
Using |
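One way to turn a silent failure like the one described above into a visible error is to check the result of every CUDA runtime call, including the kernel launch. The following is a minimal sketch, not code from this thread: the `CUDA_CHECK` macro and the trivial `kernel` are hypothetical names introduced for illustration. If the binary contains no code for the GPU's architecture, the check after the launch would typically report something like "no kernel image is available for execution on the device" instead of the program exiting quietly. Whether that is the cause in this particular setup is an assumption.

```cuda
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical helper: abort with a readable message if a CUDA runtime call fails.
#define CUDA_CHECK(expr)                                                  \
  do {                                                                    \
    cudaError_t err_ = (expr);                                            \
    if (err_ != cudaSuccess) {                                            \
      std::fprintf(stderr, "%s:%d: %s failed: %s\n", __FILE__, __LINE__,  \
                   #expr, cudaGetErrorString(err_));                      \
      std::exit(EXIT_FAILURE);                                            \
    }                                                                     \
  } while (0)

// Trivial kernel used only to prove that device code actually ran.
__global__ void kernel(int *out) { *out = 42; }

int main() {
  // Print the device's compute capability so it can be compared with the
  // architectures the binary was compiled for.
  int device = 0;
  cudaDeviceProp prop;
  CUDA_CHECK(cudaGetDevice(&device));
  CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
  std::printf("device: %s (sm_%d%d)\n", prop.name, prop.major, prop.minor);

  int *d_out = nullptr;
  CUDA_CHECK(cudaMalloc(&d_out, sizeof(int)));

  kernel<<<1, 1>>>(d_out);
  // The launch itself returns nothing; launch errors are queried here.
  CUDA_CHECK(cudaGetLastError());
  // Errors raised while the kernel executes surface at synchronization.
  CUDA_CHECK(cudaDeviceSynchronize());

  int h_out = 0;
  CUDA_CHECK(cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost));
  CUDA_CHECK(cudaFree(d_out));
  std::printf("kernel wrote %d\n", h_out);
  return 0;
}
```

Comparing the printed compute capability with the architectures passed to the compiler is a quick way to rule an architecture mismatch in or out before looking at the build rules themselves.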
Thanks for the extremely fast and detailed response. I just tried a few things based on your inputs, but no luck. I should add that I'm working in Ubuntu 22.04 in WSL 2 if it's relevant. The same code was running fine a few versions ago using |
Does the tool |
@hypdeb Could you please try |
If we're doing |
I am already using clang as my cc compiler: https://github.com/hypdeb/lawrencium/blob/1694b0f1707d2bc6d2a782a734749ae1c1379336/toolchain/cc_toolchain_config.bzl#L24
among many others. |
Here are the exact commands run by Bazel:
and then
|
I think it's unlikely the issue is with |
Closed by #158 |
Clang is capable of CUDA compilation these days.
It would be great to add support for using it for CUDA compilation with bazel.