Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_fla

package(default_visibility = ["//visibility:public"])

CUDA_PLATFORMS = [
"linux-x86_64",
"linux-sbsa",
"linux-aarch64",
"linux-ppc64le",
"windows-x86_64",
]

bzl_library(
name = "bzl_srcs",
srcs = glob(["*.bzl"]),
Expand Down Expand Up @@ -38,6 +46,37 @@ config_setting(
flag_values = {"@cuda//:valid_toolchain_found": "True"},
)

string_flag(
name = "version",
build_setting_default = "13.0.0",
)

string_flag(
name = "target_platform",
build_setting_default = "linux-x86_64",
values = CUDA_PLATFORMS,
)

[
config_setting(
name = "target_platform_is_{}".format(platform.replace("-", "_")),
flag_values = {":target_platform": platform},
) for platform in CUDA_PLATFORMS
]

string_flag(
name = "exec_platform",
build_setting_default = "linux-x86_64",
values = CUDA_PLATFORMS,
)

[
config_setting(
name = "exec_platform_is_{}".format(platform.replace("-", "_")),
flag_values = {":exec_platform": platform},
) for platform in CUDA_PLATFORMS
]

# Command line flag to specify the list of CUDA architectures to compile for.
#
# Provides CudaArchsInfo of the list of archs to build.
Expand Down
3 changes: 3 additions & 0 deletions cuda/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Core rules for building CUDA projects.
"""

load("//cuda/private:defs.bzl", _requires_cuda = "requires_cuda")
load("//cuda/private:errors.bzl", _unsupported_cuda_version = "unsupported_cuda_version", _unsupported_cuda_platform = "unsupported_cuda_platform")
load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary")
load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test")
load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows")
Expand Down Expand Up @@ -47,3 +48,5 @@ if_windows = _if_windows
cc_import_versioned_sos = _cc_import_versioned_sos

requires_cuda = _requires_cuda
unsupported_cuda_version = _unsupported_cuda_version
unsupported_cuda_platform = _unsupported_cuda_platform
15 changes: 14 additions & 1 deletion cuda/dummy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@ cc_binary(
defines = ["TOOLNAME=nvcc"],
)

cc_binary(
name = "cicc",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=cicc"],
)

cc_binary(
name = "nvlink",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=nvlink"],
)

exports_files(["link.stub"])
exports_files(["link.stub", "libdevice.10.bc"])

cc_binary(
name = "bin2c",
Expand All @@ -25,3 +31,10 @@ cc_binary(
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=fatbinary"],
)

# Empty cc_library that provides CcInfo for components not available in this CUDA version.
cc_library(
name = "dummy",
srcs = [],
hdrs = [],
)
1 change: 1 addition & 0 deletions cuda/dummy/libdevice.10.bc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#error libdevice.10.bc of cuda toolkit does not exist
124 changes: 109 additions & 15 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

load("//cuda/private:redist_json_helper.bzl", "redist_json_helper")
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit")
load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")

cuda_component_tag = tag_class(attrs = {
"name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"),
Expand Down Expand Up @@ -53,6 +54,9 @@ cuda_redist_json_tag = tag_class(attrs = {
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
"If all downloads fail, the rule will fail.",
),
"platforms": attr.string_list(
doc = "A list of platforms to generate components for.",
),
"version": attr.string(
doc = "Generate a URL by using the specified version." +
"This URL will be tried after all URLs specified in the `urls` attribute.",
Expand All @@ -74,6 +78,40 @@ cuda_toolkit_tag = tag_class(attrs = {
),
})

platform_alias_tag = tag_class(
attrs = {
"name": attr.string(
mandatory = True,
doc = "Name of the alias repository to create",
),
"component_name": attr.string(
mandatory = True,
doc = "Name of the component to create aliases for",
),
"linux_x86_64_repo": attr.string(
mandatory = True,
doc = "Name of the repository to use for x86_64 platform",
),
"linux_aarch64_repo": attr.string(
mandatory = True,
doc = "Name of the repository to use for ARM64/Jetpack platform",
),
"linux_sbsa_repo": attr.string(
mandatory = True,
doc = "Name of the repository to use for SBSA platform",
),
Comment on lines +91 to +102
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would a platform_mapping with attr.string_dict be better here? Validating the key in the rule impl seem to be future proof.

"versions": attr.string_list(
mandatory = True,
doc = "List of versions to create aliases for",
),
},
doc = """Defines a platform-specific alias repository.

Each alias tag creates a repository with targets that select between
x86_64 and ARM64 repositories based on the build platform.
""",
)

def _find_modules(module_ctx):
root = None
our_module = None
Expand All @@ -95,17 +133,20 @@ def _module_tag_to_dict(t):
def _redist_json_impl(module_ctx, attr):
url, json_object = redist_json_helper.get(module_ctx, attr)
redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object)
component_specs = redist_json_helper.collect_specs(module_ctx, attr, json_object, url)

mapping = {}
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
mapping[spec["component_name"]] = "@" + repo_name
platform_mapping = {}
for platform in attr.platforms:
component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url)
mapping = {}
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
mapping[spec["component_name"]] = repo_name

attr = {key: value for key, value in spec.items()}
attr["name"] = repo_name
cuda_component(**attr)
return redist_ver, mapping
component_attr = {key: value for key, value in spec.items()}
component_attr["name"] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
cuda_component(**component_attr)
platform_mapping[platform] = mapping
return redist_ver, platform_mapping

def _impl(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
Expand All @@ -117,22 +158,53 @@ def _impl(module_ctx):
components = root.tags.component
redist_jsons = root.tags.redist_json
toolkits = root.tags.toolkit
platform_aliases = root.tags.platform_alias
else:
components = rules_cuda.tags.component
redist_jsons = rules_cuda.tags.redist_json
toolkits = rules_cuda.tags.toolkit

platform_aliases = rules_cuda.tags.platform_alias
for component in components:
cuda_component(**_module_tag_to_dict(component))

if len(redist_jsons) > 1:
fail("Using multiple cuda.redist_json is not supported yet.")

redist_version = None
components_mapping = None
redist_versions = []
redist_components_mapping = {}

# Track all versioned repositories for each component and platform.
versioned_repos = {}
for redist_json in redist_jsons:
redist_version, components_mapping = _redist_json_impl(module_ctx, redist_json)
components_mapping = {}
redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json)
redist_versions.append(redist_version)
for platform in platform_mapping.keys():
for component_name, repo_name in platform_mapping[platform].items():
redist_components_mapping[component_name] = repo_name

# Track the versioned repo name for this component/platform/version.
if component_name not in versioned_repos:
versioned_repos[component_name] = {}
if platform not in versioned_repos[component_name]:
versioned_repos[component_name][platform] = {}
versioned_repos[component_name][platform][redist_version] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_")

for component_name in redist_components_mapping.keys():
# Build dictionaries mapping versions to repo names for each platform.
x86_64_repos = {ver: versioned_repos[component_name]["linux-x86_64"][ver] for ver in redist_versions if "linux-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-x86_64"]}
aarch64_repos = {ver: versioned_repos[component_name]["linux-aarch64"][ver] for ver in redist_versions if "linux-aarch64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-aarch64"]}
sbsa_repos = {ver: versioned_repos[component_name]["linux-sbsa"][ver] for ver in redist_versions if "linux-sbsa" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-sbsa"]}

platform_alias_repo(
name = redist_components_mapping[component_name],
repo_name = redist_components_mapping[component_name],
component_name = component_name,
linux_x86_64_repos = x86_64_repos,
linux_aarch64_repos = aarch64_repos,
linux_sbsa_repos = sbsa_repos,
versions = redist_versions,
)
components_mapping[component_name] = "@" + redist_components_mapping[component_name]
registrations = {}
for toolkit in toolkits:
if toolkit.name in registrations.keys():
Expand All @@ -148,15 +220,37 @@ def _impl(module_ctx):

for _, toolkit in registrations.items():
if components_mapping != None:
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = redist_version)
# Always use the maximum version so the toolkit includes all components.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not quite true if CTK delete some component in the future. I think a union across all CTK versions will be a little bit more robust.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could take some work, right now the version in cuda_toolkit isn't going to necessarily be correct since it's pointing to @cuda which can point to any number of versioned cuda repos, but I don't know if that gets used anywhere in the rules so I'll try removing it and see what falls out...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic is pretty deeply embedded in the repository rules where I can't use the value of a flag. I might need to go back and add the ability to register multiple toolkits to get everything to work as expected...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets leave it for future improvement, just point it out :)

# Components that don't exist in older versions will fall back to dummy.
toolkit_version = redist_versions[0]
for ver in redist_versions:
ver_parts = [int(x) for x in ver.split(".")]
tv_parts = [int(x) for x in toolkit_version.split(".")]
if ver_parts > tv_parts:
toolkit_version = ver

cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version)
else:
cuda_toolkit(**_module_tag_to_dict(toolkit))

for alias_tag in platform_aliases:
# Create a repository for each alias tag
platform_alias_repo(
name = alias_tag.name,
repo_name = alias_tag.name,
component_name = alias_tag.component_name,
linux_x86_64_repo = alias_tag.linux_x86_64_repo,
linux_aarch64_repo = alias_tag.linux_aarch64_repo,
linux_sbsa_repo = alias_tag.linux_sbsa_repo,
versions = alias_tag.versions,
)

toolchain = module_extension(
implementation = _impl,
tag_classes = {
"component": cuda_component_tag,
"redist_json": cuda_redist_json_tag,
"toolkit": cuda_toolkit_tag,
"platform_alias": platform_alias_tag,
},
)
Loading
Loading