Skip to content

Commit

Permalink
Add clang toolchain config and some other minor related changes
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Jun 9, 2022
1 parent c92eaf8 commit c13ebaa
Show file tree
Hide file tree
Showing 8 changed files with 607 additions and 1 deletion.
11 changes: 10 additions & 1 deletion cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ cuda_archs_flag(
build_setting_default = "",
)

#TODO:
# Command line flag to select compiler for cuda_library() code.
string_flag(
name = "compiler",
Expand All @@ -32,6 +31,16 @@ string_flag(
],
)

config_setting(
name = "compiler_is_nvcc",
flag_values = {":compiler": "nvcc"},
)

config_setting(
name = "compiler_is_clang",
flag_values = {":compiler": "clang"},
)

#TODO:
# Command line flag for copts to add to cuda_library() compile command.
string_list_flag(
Expand Down
2 changes: 2 additions & 0 deletions cuda/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ load(
_find_cuda_toolchain = "find_cuda_toolchain",
_use_cuda_toolchain = "use_cuda_toolchain",
)
load("//cuda/private:toolchain_configs/clang.bzl", _cuda_toolchain_config_clang = "cuda_toolchain_config")
load("//cuda/private:toolchain_configs/nvcc.bzl", _cuda_toolchain_config_nvcc = "cuda_toolchain_config")
load("//cuda/private:toolchain_configs/nvcc_msvc.bzl", _cuda_toolchain_config_nvcc_msvc = "cuda_toolchain_config")

cuda_toolkit = _cuda_toolkit
cuda_toolchain = _cuda_toolchain
find_cuda_toolchain = _find_cuda_toolchain
use_cuda_toolchain = _use_cuda_toolchain
cuda_toolchain_config_clang = _cuda_toolchain_config_clang
cuda_toolchain_config_nvcc_msvc = _cuda_toolchain_config_nvcc_msvc
cuda_toolchain_config_nvcc = _cuda_toolchain_config_nvcc

Expand Down
1 change: 1 addition & 0 deletions cuda/private/cuda_toolkit.bzl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
def register_detected_cuda_toolchains():
native.register_toolchains(
"@local_cuda//toolchain:nvcc-local-toolchain",
"@local_cuda//toolchain/clang:clang-local-toolchain",
)
34 changes: 34 additions & 0 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,43 @@ def config_cuda_toolkit_and_nvcc(repository_ctx, cuda):
substitutions["%{env_tmp}"] = _to_forward_slash(env_tmp)
repository_ctx.template("toolchain/BUILD", tpl_label, substitutions = substitutions, executable = False)

def detect_clang(repository_ctx):
## Detect clang executable
# Path to clang is
# - taken from CUDA_CLANG_PATH environment variable or
# - taken from BAZEL_LLVM environment variable as <BAZEL_LLVM>/bin/clang or
# - determined through 'which clang' or
# - treated as being not detected and not configured
bin_ext = ".exe" if _is_windows(repository_ctx) else ""
clang_path = repository_ctx.os.environ.get("CUDA_CLANG_PATH", None)
if clang_path == None:
bazel_llvm = repository_ctx.os.environ.get("BAZEL_LLVM", None)
if bazel_llvm != None and repository_ctx.path(bazel_llvm + "/bin/clang" + bin_ext).exists:
clang_path = bazel_llvm + "/bin/clang" + bin_ext
if clang_path == None:
clang_path = str(repository_ctx.which("clang"))
return clang_path

def config_clang(repository_ctx, cuda, clang_path):
# Generate @local_cuda//toolchain/clang/BUILD
tpl_label = Label("//cuda:templates/BUILD.local_toolchain_clang")
substitutions = {
"%{clang_path}": _to_forward_slash(clang_path),
"%{cuda_path}": _to_forward_slash(cuda.path),
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
"%{nvlink_label}": cuda.nvlink_label,
"%{link_stub_label}": cuda.link_stub_label,
"%{bin2c_label}": cuda.bin2c_label,
"%{fatbinary_label}": cuda.fatbinary_label,
}
repository_ctx.template("toolchain/clang/BUILD", tpl_label, substitutions = substitutions, executable = False)

def _local_cuda_impl(repository_ctx):
cuda = detect_cuda_toolkit(repository_ctx)
config_cuda_toolkit_and_nvcc(repository_ctx, cuda)
clang_path = detect_clang(repository_ctx)
if clang_path != None:
config_clang(repository_ctx, cuda, clang_path)

_local_cuda = repository_rule(
implementation = _local_cuda_impl,
Expand Down
Loading

0 comments on commit c13ebaa

Please sign in to comment.