-
-
Notifications
You must be signed in to change notification settings - Fork 64
Add support for multi-arch and multi-platform cuda toolchains #422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6f3d716
235a94a
d6a7478
5e605df
7bc9c61
3e1f206
53f739e
4233145
9fbd1ab
d9cfe1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| #error libdevice.10.bc of cuda toolkit does not exist |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| load("//cuda/private:redist_json_helper.bzl", "redist_json_helper") | ||
| load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit") | ||
| load("//cuda:platform_alias_extension.bzl", "platform_alias_repo") | ||
|
|
||
| cuda_component_tag = tag_class(attrs = { | ||
| "name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"), | ||
|
|
@@ -53,6 +54,9 @@ cuda_redist_json_tag = tag_class(attrs = { | |
| "URLs are tried in order until one succeeds, so you should list local mirrors first. " + | ||
| "If all downloads fail, the rule will fail.", | ||
| ), | ||
| "platforms": attr.string_list( | ||
| doc = "A list of platforms to generate components for.", | ||
| ), | ||
| "version": attr.string( | ||
| doc = "Generate a URL by using the specified version." + | ||
| "This URL will be tried after all URLs specified in the `urls` attribute.", | ||
|
|
@@ -74,6 +78,40 @@ cuda_toolkit_tag = tag_class(attrs = { | |
| ), | ||
| }) | ||
|
|
||
| platform_alias_tag = tag_class( | ||
| attrs = { | ||
| "name": attr.string( | ||
| mandatory = True, | ||
| doc = "Name of the alias repository to create", | ||
| ), | ||
| "component_name": attr.string( | ||
| mandatory = True, | ||
| doc = "Name of the component to create aliases for", | ||
| ), | ||
| "linux_x86_64_repo": attr.string( | ||
| mandatory = True, | ||
| doc = "Name of the repository to use for x86_64 platform", | ||
| ), | ||
| "linux_aarch64_repo": attr.string( | ||
| mandatory = True, | ||
| doc = "Name of the repository to use for ARM64/Jetpack platform", | ||
| ), | ||
| "linux_sbsa_repo": attr.string( | ||
| mandatory = True, | ||
| doc = "Name of the repository to use for SBSA platform", | ||
| ), | ||
| "versions": attr.string_list( | ||
| mandatory = True, | ||
| doc = "List of versions to create aliases for", | ||
| ), | ||
| }, | ||
| doc = """Defines a platform-specific alias repository. | ||
|
|
||
| Each alias tag creates a repository with targets that select between | ||
| x86_64 and ARM64 repositories based on the build platform. | ||
| """, | ||
| ) | ||
|
|
||
| def _find_modules(module_ctx): | ||
| root = None | ||
| our_module = None | ||
|
|
@@ -95,17 +133,20 @@ def _module_tag_to_dict(t): | |
| def _redist_json_impl(module_ctx, attr): | ||
| url, json_object = redist_json_helper.get(module_ctx, attr) | ||
| redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object) | ||
| component_specs = redist_json_helper.collect_specs(module_ctx, attr, json_object, url) | ||
|
|
||
| mapping = {} | ||
| for spec in component_specs: | ||
| repo_name = redist_json_helper.get_repo_name(module_ctx, spec) | ||
| mapping[spec["component_name"]] = "@" + repo_name | ||
| platform_mapping = {} | ||
| for platform in attr.platforms: | ||
| component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url) | ||
| mapping = {} | ||
| for spec in component_specs: | ||
| repo_name = redist_json_helper.get_repo_name(module_ctx, spec) | ||
| mapping[spec["component_name"]] = repo_name | ||
|
|
||
| attr = {key: value for key, value in spec.items()} | ||
| attr["name"] = repo_name | ||
| cuda_component(**attr) | ||
| return redist_ver, mapping | ||
| component_attr = {key: value for key, value in spec.items()} | ||
| component_attr["name"] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_") | ||
| cuda_component(**component_attr) | ||
| platform_mapping[platform] = mapping | ||
| return redist_ver, platform_mapping | ||
|
|
||
| def _impl(module_ctx): | ||
| # Toolchain configuration is only allowed in the root module, or in rules_cuda. | ||
|
|
@@ -117,22 +158,53 @@ def _impl(module_ctx): | |
| components = root.tags.component | ||
| redist_jsons = root.tags.redist_json | ||
| toolkits = root.tags.toolkit | ||
| platform_aliases = root.tags.platform_alias | ||
| else: | ||
| components = rules_cuda.tags.component | ||
| redist_jsons = rules_cuda.tags.redist_json | ||
| toolkits = rules_cuda.tags.toolkit | ||
|
|
||
| platform_aliases = rules_cuda.tags.platform_alias | ||
| for component in components: | ||
| cuda_component(**_module_tag_to_dict(component)) | ||
|
|
||
| if len(redist_jsons) > 1: | ||
| fail("Using multiple cuda.redist_json is not supported yet.") | ||
|
|
||
| redist_version = None | ||
| components_mapping = None | ||
| redist_versions = [] | ||
| redist_components_mapping = {} | ||
|
|
||
| # Track all versioned repositories for each component and platform. | ||
| versioned_repos = {} | ||
| for redist_json in redist_jsons: | ||
| redist_version, components_mapping = _redist_json_impl(module_ctx, redist_json) | ||
| components_mapping = {} | ||
| redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json) | ||
| redist_versions.append(redist_version) | ||
| for platform in platform_mapping.keys(): | ||
| for component_name, repo_name in platform_mapping[platform].items(): | ||
| redist_components_mapping[component_name] = repo_name | ||
|
|
||
| # Track the versioned repo name for this component/platform/version. | ||
| if component_name not in versioned_repos: | ||
| versioned_repos[component_name] = {} | ||
| if platform not in versioned_repos[component_name]: | ||
| versioned_repos[component_name][platform] = {} | ||
| versioned_repos[component_name][platform][redist_version] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_") | ||
|
|
||
| for component_name in redist_components_mapping.keys(): | ||
| # Build dictionaries mapping versions to repo names for each platform. | ||
| x86_64_repos = {ver: versioned_repos[component_name]["linux-x86_64"][ver] for ver in redist_versions if "linux-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-x86_64"]} | ||
| aarch64_repos = {ver: versioned_repos[component_name]["linux-aarch64"][ver] for ver in redist_versions if "linux-aarch64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-aarch64"]} | ||
| sbsa_repos = {ver: versioned_repos[component_name]["linux-sbsa"][ver] for ver in redist_versions if "linux-sbsa" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-sbsa"]} | ||
|
|
||
| platform_alias_repo( | ||
| name = redist_components_mapping[component_name], | ||
| repo_name = redist_components_mapping[component_name], | ||
| component_name = component_name, | ||
| linux_x86_64_repos = x86_64_repos, | ||
| linux_aarch64_repos = aarch64_repos, | ||
| linux_sbsa_repos = sbsa_repos, | ||
| versions = redist_versions, | ||
| ) | ||
| components_mapping[component_name] = "@" + redist_components_mapping[component_name] | ||
| registrations = {} | ||
| for toolkit in toolkits: | ||
| if toolkit.name in registrations.keys(): | ||
|
|
@@ -148,15 +220,37 @@ def _impl(module_ctx): | |
|
|
||
| for _, toolkit in registrations.items(): | ||
| if components_mapping != None: | ||
| cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = redist_version) | ||
| # Always use the maximum version so the toolkit includes all components. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not quite true if CTK delete some component in the future. I think a union across all CTK versions will be a little bit more robust.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could take some work, right now the version in cuda_toolkit isn't going to necessarily be correct since it's pointing to @cuda which can point to any number of versioned cuda repos, but I don't know if that gets used anywhere in the rules so I'll try removing it and see what falls out...
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic is pretty deeply embedded in the repository rules where I can't use the value of a flag. I might need to go back and add the ability to register multiple toolkits to get everything to work as expected...
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lets leave it for future improvement, just point it out :) |
||
| # Components that don't exist in older versions will fall back to dummy. | ||
| toolkit_version = redist_versions[0] | ||
| for ver in redist_versions: | ||
| ver_parts = [int(x) for x in ver.split(".")] | ||
| tv_parts = [int(x) for x in toolkit_version.split(".")] | ||
| if ver_parts > tv_parts: | ||
| toolkit_version = ver | ||
|
|
||
| cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version) | ||
| else: | ||
| cuda_toolkit(**_module_tag_to_dict(toolkit)) | ||
|
|
||
| for alias_tag in platform_aliases: | ||
| # Create a repository for each alias tag | ||
| platform_alias_repo( | ||
| name = alias_tag.name, | ||
| repo_name = alias_tag.name, | ||
| component_name = alias_tag.component_name, | ||
| linux_x86_64_repo = alias_tag.linux_x86_64_repo, | ||
| linux_aarch64_repo = alias_tag.linux_aarch64_repo, | ||
| linux_sbsa_repo = alias_tag.linux_sbsa_repo, | ||
| versions = alias_tag.versions, | ||
| ) | ||
|
|
||
| toolchain = module_extension( | ||
| implementation = _impl, | ||
| tag_classes = { | ||
| "component": cuda_component_tag, | ||
| "redist_json": cuda_redist_json_tag, | ||
| "toolkit": cuda_toolkit_tag, | ||
| "platform_alias": platform_alias_tag, | ||
| }, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would a
platform_mappingwithattr.string_dictbe better here? Validating thekeyin the rule impl seem to be future proof.