Skip to content

Commit

Permalink
Add cuda_toolkit rule and CudaToolkitInfo provider for tool File reso…
Browse files Browse the repository at this point in the history
…lution.
  • Loading branch information
cloudhan committed Jun 9, 2022
1 parent f2701cf commit 01124b1
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 24 deletions.
2 changes: 2 additions & 0 deletions cuda/defs.bzl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("//cuda/private:providers.bzl", _CudaArchsInfo = "CudaArchsInfo", _cuda_archs = "cuda_archs")
load("//cuda/private:rules/cuda_objects.bzl", _cuda_objects = "cuda_objects")
load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library")
load("//cuda/private:rules/cuda_toolkit.bzl", _cuda_toolkit = "cuda_toolkit")
load(
"//cuda/private:toolchain.bzl",
_cuda_toolchain = "cuda_toolchain",
Expand All @@ -10,6 +11,7 @@ load(
load("//cuda/private:toolchain_configs/nvcc.bzl", _cuda_toolchain_config_nvcc = "cuda_toolchain_config")
load("//cuda/private:toolchain_configs/nvcc_msvc.bzl", _cuda_toolchain_config_nvcc_msvc = "cuda_toolchain_config")

cuda_toolkit = _cuda_toolkit
cuda_toolchain = _cuda_toolchain
find_cuda_toolchain = _find_cuda_toolchain
use_cuda_toolchain = _use_cuda_toolchain
Expand Down
21 changes: 21 additions & 0 deletions cuda/dummy/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package(default_visibility = ["//visibility:public"])

cc_binary(
name = "nvlink",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=nvlink"],
)

exports_files(["link.stub"])

cc_binary(
name = "bin2c",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=bin2c"],
)

cc_binary(
name = "fatbinary",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=fatbinary"],
)
9 changes: 9 additions & 0 deletions cuda/dummy/dummy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#include "cstdio"

#define TO_STRING_IND(X) #X
#define TO_STRING(X) TO_STRING_IND(X)

int main(int argc, char* argv[]) {
std::printf("ERROR: " TO_STRING(TOOLNAME) " of cuda toolkit does not exist\n");
return -1;
}
1 change: 1 addition & 0 deletions cuda/dummy/link.stub
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#error link.stub of cuda toolkit does not exist
15 changes: 14 additions & 1 deletion cuda/private/providers.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,25 @@ CudaInfo = provider(
},
)

CudaToolkitInfo = provider(
"",
fields = {
"path": "string of path to cuda toolkit root",
"version_major": "int of the cuda toolkit major version, e.g, 11 for 11.6",
"version_minor": "int of the cuda toolkit minor version, e.g, 6 for 11.6",
"nvlink": "File to the nvlink executable",
"link_stub": "File to the link.stub file",
"bin2c": "File to the bin2c executable",
"fatbinary": "File to the fatbinary executable",
}
)

CudaToolchainConfigInfo = provider(
"""""",
fields = {
"action_configs": "A list of action_configs.",
"artifact_name_patterns": "A list of artifact_name_patterns.",
"cuda_path": "cuda toolkit root path",
"cuda_toolkit": "CudaToolkitInfo",
"features": "A list of features.",
"toolchain_identifier": "nvcc or clang",
},
Expand Down
39 changes: 27 additions & 12 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def _to_forward_slash(s):
def _is_linux(ctx):
return ctx.os.name.startswith("linux")

def _is_windows(ctx):
return ctx.os.name.lower().startswith("windows")

def _get_nvcc_version(repository_ctx, cuda_path):
result = repository_ctx.execute([cuda_path + "/bin/nvcc", "--version"])
if result.return_code != 0:
Expand All @@ -38,17 +41,6 @@ def _get_nvcc_version(repository_ctx, cuda_path):
return version[:2]
return [-1, -1]

CudaToolkitInfo = provider(
"",
fields = {
"path": "path, e.g. /usr/local/cuda",
"version_major": "int, e.g. 11",
"version_minor": "int, e.g. 6",
"nvcc_version_major": "int, e.g. 11",
"nvcc_version_minor": "int, e.g. 6",
},
)

def detect_cuda_toolkit(repository_ctx):
## Detect CUDA Toolkit
# Path to CUDA Toolkit is
Expand All @@ -67,20 +59,38 @@ def detect_cuda_toolkit(repository_ctx):
# if cuda_path == None:
# fail("Cannot determine CUDA Toolkit root, abort!")

bin_ext = ".exe" if _is_windows(repository_ctx) else ""
nvlink = "@rules_cuda//cuda/dummy:nvlink"
link_stub = "@rules_cuda//cuda/dummy:link.stub"
bin2c = "@rules_cuda//cuda/dummy:bin2c"
fatbinary = "@rules_cuda//cuda/dummy:fatbinary"
if repository_ctx.path(cuda_path + "/bin/nvlink" + bin_ext).exists:
nvlink = "@local_cuda//:cuda/bin/nvlink" + bin_ext
if repository_ctx.path(cuda_path + "/bin/crt/link.stub").exists:
link_stub = "@local_cuda//:cuda/bin/crt/link.stub"
if repository_ctx.path(cuda_path + "/bin/bin2c" + bin_ext).exists:
bin2c = "@local_cuda//:cuda/bin/bin2c" + bin_ext
if repository_ctx.path(cuda_path + "/bin/fatbinary" + bin_ext).exists:
fatbinary = "@local_cuda//:cuda/bin/fatbinary" + bin_ext

nvcc_version_major = -1
nvcc_version_minor = -1

if repository_ctx.path(cuda_path).exists:
nvcc_version_major, nvcc_version_minor = _get_nvcc_version(repository_ctx, cuda_path)

return CudaToolkitInfo(
return struct(
path = cuda_path,
# this should have been extracted from cuda.h, reuse nvcc for now
version_major = nvcc_version_major,
version_minor = nvcc_version_minor,
# this is extracted from `nvcc --version`
nvcc_version_major = nvcc_version_major,
nvcc_version_minor = nvcc_version_minor,
nvlink_label = nvlink,
link_stub_label = link_stub,
bin2c_label = bin2c,
fatbinary_label = fatbinary,
)

def config_cuda_toolkit_and_nvcc(repository_ctx, cuda):
Expand All @@ -103,8 +113,13 @@ def config_cuda_toolkit_and_nvcc(repository_ctx, cuda):
)
substitutions = {
"%{cuda_path}": _to_forward_slash(cuda.path),
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
"%{nvcc_version_major}": str(cuda.nvcc_version_major),
"%{nvcc_version_minor}": str(cuda.nvcc_version_minor),
"%{nvlink_label}": cuda.nvlink_label,
"%{link_stub_label}": cuda.link_stub_label,
"%{bin2c_label}": cuda.bin2c_label,
"%{fatbinary_label}": cuda.fatbinary_label,
}
env_tmp = repository_ctx.os.environ.get("TMP", repository_ctx.os.environ.get("TEMP", None))
if env_tmp != None:
Expand Down
26 changes: 26 additions & 0 deletions cuda/private/rules/cuda_toolkit.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
load("//cuda/private:providers.bzl", "CudaToolkitInfo")

def _impl(ctx):
version_major, version_minor = ctx.attr.version.split(".")[:2]
return CudaToolkitInfo(
path = ctx.attr.path,
version_major = int(version_major),
version_minor = int(version_minor),
nvlink = ctx.file.nvlink,
link_stub = ctx.file.link_stub,
bin2c = ctx.file.bin2c,
fatbinary = ctx.file.fatbinary,
)

cuda_toolkit = rule(
implementation = _impl,
attrs = {
"path": attr.string(mandatory = True),
"version": attr.string(mandatory = True),
"nvlink": attr.label(allow_single_file = True),
"link_stub": attr.label(allow_single_file = True),
"bin2c": attr.label(allow_single_file = True),
"fatbinary": attr.label(allow_single_file = True),
},
provides = [CudaToolkitInfo],
)
6 changes: 5 additions & 1 deletion cuda/private/toolchain.bzl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo")
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo")
load("//cuda/private:toolchain_config_lib.bzl", "config_helper")

def _cuda_toolchain_impl(ctx):
Expand All @@ -19,6 +19,7 @@ def _cuda_toolchain_impl(ctx):
compiler_executable = ctx.attr.compiler_executable,
selectables_info = selectables_info,
artifact_name_patterns = artifact_name_patterns,
cuda_toolkit = cuda_toolchain_config.cuda_toolkit,
),
]

Expand Down Expand Up @@ -51,3 +52,6 @@ def use_cuda_toolchain():

def find_cuda_toolchain(ctx):
return ctx.toolchains[CUDA_TOOLCHAIN_TYPE]

def find_cuda_toolkit(ctx):
return ctx.toolchains[CUDA_TOOLCHAIN_TYPE].cuda_toolkit[CudaToolkitInfo]
6 changes: 3 additions & 3 deletions cuda/private/toolchain_configs/nvcc.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
load("//cuda/private:action_names.bzl", "ACTION_NAMES")
load("//cuda/private:artifact_categories.bzl", "ARTIFACT_CATEGORIES")
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo")
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo")
load("//cuda/private:toolchain.bzl", "use_cpp_toolchain")
load("//cuda/private:toolchain_configs/utils.bzl", "nvcc_version_ge")
load(
Expand Down Expand Up @@ -415,13 +415,13 @@ def _impl(ctx):
features = features,
artifact_name_patterns = artifact_name_patterns,
toolchain_identifier = ctx.attr.toolchain_identifier,
cuda_path = ctx.attr.cuda_path,
cuda_toolkit = ctx.attr.cuda_toolkit,
)]

cuda_toolchain_config = rule(
implementation = _impl,
attrs = {
"cuda_path": attr.string(default = "/usr/local/cuda"),
"cuda_toolkit": attr.label(mandatory = True, providers = [CudaToolkitInfo]),
"toolchain_identifier": attr.string(values = ["nvcc"], mandatory = True),
"nvcc_version_major": attr.int(),
"nvcc_version_minor": attr.int(),
Expand Down
6 changes: 3 additions & 3 deletions cuda/private/toolchain_configs/nvcc_msvc.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
load("//cuda/private:action_names.bzl", "ACTION_NAMES")
load("//cuda/private:artifact_categories.bzl", "ARTIFACT_CATEGORIES")
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo")
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo")
load("//cuda/private:toolchain.bzl", "use_cpp_toolchain")
load("//cuda/private:toolchain_configs/utils.bzl", "nvcc_version_ge")
load(
Expand Down Expand Up @@ -504,13 +504,13 @@ def _impl(ctx):
features = features,
artifact_name_patterns = artifact_name_patterns,
toolchain_identifier = ctx.attr.toolchain_identifier,
cuda_path = ctx.attr.cuda_path,
cuda_toolkit = ctx.attr.cuda_toolkit,
)]

cuda_toolchain_config = rule(
implementation = _impl,
attrs = {
"cuda_path": attr.string(default = "/usr/local/cuda"),
"cuda_toolkit": attr.label(mandatory = True, providers = [CudaToolkitInfo]),
"toolchain_identifier": attr.string(values = ["nvcc"], mandatory = True),
"nvcc_version_major": attr.int(),
"nvcc_version_minor": attr.int(),
Expand Down
15 changes: 13 additions & 2 deletions cuda/templates/BUILD.local_toolchain_nvcc
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,26 @@
load(
"@rules_cuda//cuda:defs.bzl",
"cuda_toolchain",
"cuda_toolkit",
cuda_toolchain_config = "cuda_toolchain_config_nvcc",
)

cuda_toolkit(
name = "cuda-toolkit",
bin2c = "%{bin2c_label}",
fatbinary = "%{bin2c_label}",
link_stub = "%{link_stub_label}",
nvlink = "%{nvlink_label}",
path = "%{cuda_path}",
version = "%{cuda_version}",
)

cuda_toolchain_config(
name = "nvcc-local-config",
cuda_path = "%{cuda_path}",
toolchain_identifier = "nvcc",
cuda_toolkit = ":cuda-toolkit",
nvcc_version_major = %{nvcc_version_major},
nvcc_version_minor = %{nvcc_version_minor},
toolchain_identifier = "nvcc",
)

cuda_toolchain(
Expand Down
15 changes: 13 additions & 2 deletions cuda/templates/BUILD.local_toolchain_nvcc_msvc
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,27 @@
load(
"@rules_cuda//cuda:defs.bzl",
"cuda_toolchain",
"cuda_toolkit",
cuda_toolchain_config = "cuda_toolchain_config_nvcc_msvc",
)

cuda_toolkit(
name = "cuda-toolkit",
bin2c = "%{bin2c_label}",
fatbinary = "%{bin2c_label}",
link_stub = "%{link_stub_label}",
nvlink = "%{nvlink_label}",
path = "%{cuda_path}",
version = "%{cuda_version}",
)

cuda_toolchain_config(
name = "nvcc-local-config",
cuda_path = "%{cuda_path}",
cuda_toolkit = ":cuda-toolkit",
msvc_env_tmp = "%{env_tmp}",
toolchain_identifier = "nvcc",
nvcc_version_major = %{nvcc_version_major},
nvcc_version_minor = %{nvcc_version_minor},
toolchain_identifier = "nvcc",
)

cuda_toolchain(
Expand Down

0 comments on commit 01124b1

Please sign in to comment.