diff --git a/cuda/defs.bzl b/cuda/defs.bzl index d5d83c62..5dfae715 100644 --- a/cuda/defs.bzl +++ b/cuda/defs.bzl @@ -1,6 +1,7 @@ load("//cuda/private:providers.bzl", _CudaArchsInfo = "CudaArchsInfo", _cuda_archs = "cuda_archs") load("//cuda/private:rules/cuda_objects.bzl", _cuda_objects = "cuda_objects") load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") +load("//cuda/private:rules/cuda_toolkit.bzl", _cuda_toolkit = "cuda_toolkit") load( "//cuda/private:toolchain.bzl", _cuda_toolchain = "cuda_toolchain", @@ -10,6 +11,7 @@ load( load("//cuda/private:toolchain_configs/nvcc.bzl", _cuda_toolchain_config_nvcc = "cuda_toolchain_config") load("//cuda/private:toolchain_configs/nvcc_msvc.bzl", _cuda_toolchain_config_nvcc_msvc = "cuda_toolchain_config") +cuda_toolkit = _cuda_toolkit cuda_toolchain = _cuda_toolchain find_cuda_toolchain = _find_cuda_toolchain use_cuda_toolchain = _use_cuda_toolchain diff --git a/cuda/dummy/BUILD.bazel b/cuda/dummy/BUILD.bazel new file mode 100644 index 00000000..87ee3af0 --- /dev/null +++ b/cuda/dummy/BUILD.bazel @@ -0,0 +1,21 @@ +package(default_visibility = ["//visibility:public"]) + +cc_binary( + name = "nvlink", + srcs = ["dummy.cpp"], + defines = ["TOOLNAME=nvlink"], +) + +exports_files(["link.stub"]) + +cc_binary( + name = "bin2c", + srcs = ["dummy.cpp"], + defines = ["TOOLNAME=bin2c"], +) + +cc_binary( + name = "fatbinary", + srcs = ["dummy.cpp"], + defines = ["TOOLNAME=fatbinary"], +) diff --git a/cuda/dummy/dummy.cpp b/cuda/dummy/dummy.cpp new file mode 100644 index 00000000..a5fa4bad --- /dev/null +++ b/cuda/dummy/dummy.cpp @@ -0,0 +1,9 @@ +#include "cstdio" + +#define TO_STRING_IND(X) #X +#define TO_STRING(X) TO_STRING_IND(X) + +int main(int argc, char* argv[]) { + std::printf("ERROR: " TO_STRING(TOOLNAME) " of cuda toolkit does not exist\n"); + return -1; +} diff --git a/cuda/dummy/link.stub b/cuda/dummy/link.stub new file mode 100644 index 00000000..cb1b2e62 --- /dev/null +++ b/cuda/dummy/link.stub @@ -0,0 +1 @@ +#error link.stub of cuda toolkit does not exist diff --git a/cuda/private/providers.bzl b/cuda/private/providers.bzl index 6d0a84ed..045b3e2e 100644 --- a/cuda/private/providers.bzl +++ b/cuda/private/providers.bzl @@ -87,12 +87,25 @@ CudaInfo = provider( }, ) +CudaToolkitInfo = provider( + "", + fields = { + "path": "string of path to cuda toolkit root", + "version_major": "int of the cuda toolkit major version, e.g, 11 for 11.6", + "version_minor": "int of the cuda toolkit minor version, e.g, 6 for 11.6", + "nvlink": "File to the nvlink executable", + "link_stub": "File to the link.stub file", + "bin2c": "File to the bin2c executable", + "fatbinary": "File to the fatbinary executable", + } +) + CudaToolchainConfigInfo = provider( """""", fields = { "action_configs": "A list of action_configs.", "artifact_name_patterns": "A list of artifact_name_patterns.", - "cuda_path": "cuda toolkit root path", + "cuda_toolkit": "CudaToolkitInfo", "features": "A list of features.", "toolchain_identifier": "nvcc or clang", }, diff --git a/cuda/private/repositories.bzl b/cuda/private/repositories.bzl index 81712080..22458b37 100644 --- a/cuda/private/repositories.bzl +++ b/cuda/private/repositories.bzl @@ -25,6 +25,9 @@ def _to_forward_slash(s): def _is_linux(ctx): return ctx.os.name.startswith("linux") +def _is_windows(ctx): + return ctx.os.name.lower().startswith("windows") + def _get_nvcc_version(repository_ctx, cuda_path): result = repository_ctx.execute([cuda_path + "/bin/nvcc", "--version"]) if result.return_code != 0: @@ -38,17 +41,6 @@ def _get_nvcc_version(repository_ctx, cuda_path): return version[:2] return [-1, -1] -CudaToolkitInfo = provider( - "", - fields = { - "path": "path, e.g. /usr/local/cuda", - "version_major": "int, e.g. 11", - "version_minor": "int, e.g. 6", - "nvcc_version_major": "int, e.g. 11", - "nvcc_version_minor": "int, e.g. 6", - }, -) - def detect_cuda_toolkit(repository_ctx): ## Detect CUDA Toolkit # Path to CUDA Toolkit is @@ -67,13 +59,27 @@ def detect_cuda_toolkit(repository_ctx): # if cuda_path == None: # fail("Cannot determine CUDA Toolkit root, abort!") + bin_ext = ".exe" if _is_windows(repository_ctx) else "" + nvlink = "@rules_cuda//cuda/dummy:nvlink" + link_stub = "@rules_cuda//cuda/dummy:link.stub" + bin2c = "@rules_cuda//cuda/dummy:bin2c" + fatbinary = "@rules_cuda//cuda/dummy:fatbinary" + if repository_ctx.path(cuda_path + "/bin/nvlink" + bin_ext).exists: + nvlink = "@local_cuda//:cuda/bin/nvlink" + bin_ext + if repository_ctx.path(cuda_path + "/bin/crt/link.stub").exists: + link_stub = "@local_cuda//:cuda/bin/crt/link.stub" + if repository_ctx.path(cuda_path + "/bin/bin2c" + bin_ext).exists: + bin2c = "@local_cuda//:cuda/bin/bin2c" + bin_ext + if repository_ctx.path(cuda_path + "/bin/fatbinary" + bin_ext).exists: + fatbinary = "@local_cuda//:cuda/bin/fatbinary" + bin_ext + nvcc_version_major = -1 nvcc_version_minor = -1 if repository_ctx.path(cuda_path).exists: nvcc_version_major, nvcc_version_minor = _get_nvcc_version(repository_ctx, cuda_path) - return CudaToolkitInfo( + return struct( path = cuda_path, # this should have been extracted from cuda.h, reuse nvcc for now version_major = nvcc_version_major, @@ -81,6 +87,10 @@ def detect_cuda_toolkit(repository_ctx): # this is extracted from `nvcc --version` nvcc_version_major = nvcc_version_major, nvcc_version_minor = nvcc_version_minor, + nvlink_label = nvlink, + link_stub_label = link_stub, + bin2c_label = bin2c, + fatbinary_label = fatbinary, ) def config_cuda_toolkit_and_nvcc(repository_ctx, cuda): @@ -103,8 +113,13 @@ def config_cuda_toolkit_and_nvcc(repository_ctx, cuda): ) substitutions = { "%{cuda_path}": _to_forward_slash(cuda.path), + "%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor), "%{nvcc_version_major}": str(cuda.nvcc_version_major), "%{nvcc_version_minor}": str(cuda.nvcc_version_minor), + "%{nvlink_label}": cuda.nvlink_label, + "%{link_stub_label}": cuda.link_stub_label, + "%{bin2c_label}": cuda.bin2c_label, + "%{fatbinary_label}": cuda.fatbinary_label, } env_tmp = repository_ctx.os.environ.get("TMP", repository_ctx.os.environ.get("TEMP", None)) if env_tmp != None: diff --git a/cuda/private/rules/cuda_toolkit.bzl b/cuda/private/rules/cuda_toolkit.bzl new file mode 100644 index 00000000..8afdc268 --- /dev/null +++ b/cuda/private/rules/cuda_toolkit.bzl @@ -0,0 +1,26 @@ +load("//cuda/private:providers.bzl", "CudaToolkitInfo") + +def _impl(ctx): + version_major, version_minor = ctx.attr.version.split(".")[:2] + return CudaToolkitInfo( + path = ctx.attr.path, + version_major = int(version_major), + version_minor = int(version_minor), + nvlink = ctx.file.nvlink, + link_stub = ctx.file.link_stub, + bin2c = ctx.file.bin2c, + fatbinary = ctx.file.fatbinary, + ) + +cuda_toolkit = rule( + implementation = _impl, + attrs = { + "path": attr.string(mandatory = True), + "version": attr.string(mandatory = True), + "nvlink": attr.label(allow_single_file = True), + "link_stub": attr.label(allow_single_file = True), + "bin2c": attr.label(allow_single_file = True), + "fatbinary": attr.label(allow_single_file = True), + }, + provides = [CudaToolkitInfo], +) diff --git a/cuda/private/toolchain.bzl b/cuda/private/toolchain.bzl index 58f3bf1a..2e602fcf 100644 --- a/cuda/private/toolchain.bzl +++ b/cuda/private/toolchain.bzl @@ -1,4 +1,4 @@ -load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo") +load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo") load("//cuda/private:toolchain_config_lib.bzl", "config_helper") def _cuda_toolchain_impl(ctx): @@ -19,6 +19,7 @@ def _cuda_toolchain_impl(ctx): compiler_executable = ctx.attr.compiler_executable, selectables_info = selectables_info, artifact_name_patterns = artifact_name_patterns, + cuda_toolkit = cuda_toolchain_config.cuda_toolkit, ), ] @@ -51,3 +52,6 @@ def use_cuda_toolchain(): def find_cuda_toolchain(ctx): return ctx.toolchains[CUDA_TOOLCHAIN_TYPE] + +def find_cuda_toolkit(ctx): + return ctx.toolchains[CUDA_TOOLCHAIN_TYPE].cuda_toolkit[CudaToolkitInfo] diff --git a/cuda/private/toolchain_configs/nvcc.bzl b/cuda/private/toolchain_configs/nvcc.bzl index 4d79a932..28671400 100644 --- a/cuda/private/toolchain_configs/nvcc.bzl +++ b/cuda/private/toolchain_configs/nvcc.bzl @@ -2,7 +2,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") load("//cuda/private:action_names.bzl", "ACTION_NAMES") load("//cuda/private:artifact_categories.bzl", "ARTIFACT_CATEGORIES") -load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo") +load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo") load("//cuda/private:toolchain.bzl", "use_cpp_toolchain") load("//cuda/private:toolchain_configs/utils.bzl", "nvcc_version_ge") load( @@ -415,13 +415,13 @@ def _impl(ctx): features = features, artifact_name_patterns = artifact_name_patterns, toolchain_identifier = ctx.attr.toolchain_identifier, - cuda_path = ctx.attr.cuda_path, + cuda_toolkit = ctx.attr.cuda_toolkit, )] cuda_toolchain_config = rule( implementation = _impl, attrs = { - "cuda_path": attr.string(default = "/usr/local/cuda"), + "cuda_toolkit": attr.label(mandatory = True, providers = [CudaToolkitInfo]), "toolchain_identifier": attr.string(values = ["nvcc"], mandatory = True), "nvcc_version_major": attr.int(), "nvcc_version_minor": attr.int(), diff --git a/cuda/private/toolchain_configs/nvcc_msvc.bzl b/cuda/private/toolchain_configs/nvcc_msvc.bzl index 97dd07e2..88427892 100644 --- a/cuda/private/toolchain_configs/nvcc_msvc.bzl +++ b/cuda/private/toolchain_configs/nvcc_msvc.bzl @@ -2,7 +2,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") load("//cuda/private:action_names.bzl", "ACTION_NAMES") load("//cuda/private:artifact_categories.bzl", "ARTIFACT_CATEGORIES") -load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo") +load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo") load("//cuda/private:toolchain.bzl", "use_cpp_toolchain") load("//cuda/private:toolchain_configs/utils.bzl", "nvcc_version_ge") load( @@ -504,13 +504,13 @@ def _impl(ctx): features = features, artifact_name_patterns = artifact_name_patterns, toolchain_identifier = ctx.attr.toolchain_identifier, - cuda_path = ctx.attr.cuda_path, + cuda_toolkit = ctx.attr.cuda_toolkit, )] cuda_toolchain_config = rule( implementation = _impl, attrs = { - "cuda_path": attr.string(default = "/usr/local/cuda"), + "cuda_toolkit": attr.label(mandatory = True, providers = [CudaToolkitInfo]), "toolchain_identifier": attr.string(values = ["nvcc"], mandatory = True), "nvcc_version_major": attr.int(), "nvcc_version_minor": attr.int(), diff --git a/cuda/templates/BUILD.local_toolchain_nvcc b/cuda/templates/BUILD.local_toolchain_nvcc index 2639e56d..862e399f 100644 --- a/cuda/templates/BUILD.local_toolchain_nvcc +++ b/cuda/templates/BUILD.local_toolchain_nvcc @@ -3,15 +3,26 @@ load( "@rules_cuda//cuda:defs.bzl", "cuda_toolchain", + "cuda_toolkit", cuda_toolchain_config = "cuda_toolchain_config_nvcc", ) +cuda_toolkit( + name = "cuda-toolkit", + bin2c = "%{bin2c_label}", + fatbinary = "%{bin2c_label}", + link_stub = "%{link_stub_label}", + nvlink = "%{nvlink_label}", + path = "%{cuda_path}", + version = "%{cuda_version}", +) + cuda_toolchain_config( name = "nvcc-local-config", - cuda_path = "%{cuda_path}", - toolchain_identifier = "nvcc", + cuda_toolkit = ":cuda-toolkit", nvcc_version_major = %{nvcc_version_major}, nvcc_version_minor = %{nvcc_version_minor}, + toolchain_identifier = "nvcc", ) cuda_toolchain( diff --git a/cuda/templates/BUILD.local_toolchain_nvcc_msvc b/cuda/templates/BUILD.local_toolchain_nvcc_msvc index 5298676a..f8e40ae5 100644 --- a/cuda/templates/BUILD.local_toolchain_nvcc_msvc +++ b/cuda/templates/BUILD.local_toolchain_nvcc_msvc @@ -3,16 +3,27 @@ load( "@rules_cuda//cuda:defs.bzl", "cuda_toolchain", + "cuda_toolkit", cuda_toolchain_config = "cuda_toolchain_config_nvcc_msvc", ) +cuda_toolkit( + name = "cuda-toolkit", + bin2c = "%{bin2c_label}", + fatbinary = "%{bin2c_label}", + link_stub = "%{link_stub_label}", + nvlink = "%{nvlink_label}", + path = "%{cuda_path}", + version = "%{cuda_version}", +) + cuda_toolchain_config( name = "nvcc-local-config", - cuda_path = "%{cuda_path}", + cuda_toolkit = ":cuda-toolkit", msvc_env_tmp = "%{env_tmp}", - toolchain_identifier = "nvcc", nvcc_version_major = %{nvcc_version_major}, nvcc_version_minor = %{nvcc_version_minor}, + toolchain_identifier = "nvcc", ) cuda_toolchain(