Skip to content

Simplify the Bazel logic used to add dependencies to tests. #28753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,22 @@ string_flag(
)

config_setting(
name = "disable_jaxlib_and_jax_build",
name = "config_build_jax_true",
flag_values = {
":build_jax": "true",
},
)

config_setting(
name = "config_build_jax_false",
flag_values = {
":build_jaxlib": "false",
":build_jax": "false",
},
)

config_setting(
name = "enable_jaxlib_and_jax_py_import",
name = "config_build_jax_wheel",
flag_values = {
":build_jaxlib": "wheel",
":build_jax": "wheel",
},
)
Expand Down
33 changes: 0 additions & 33 deletions jax_plugins/cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

load(
"//jaxlib:jax.bzl",
"if_windows",
"py_library_providing_imports_info",
"pytype_library",
)
Expand Down Expand Up @@ -58,35 +57,3 @@ py_library_providing_imports_info(
data = [":pjrt_c_api_gpu_plugin.so"],
lib_rule = pytype_library,
)

config_setting(
name = "disable_jaxlib_for_cpu_build",
flag_values = {
"//jax:build_jaxlib": "false",
"@local_config_cuda//:enable_cuda": "False",
},
)

config_setting(
name = "disable_jaxlib_for_cuda12_build",
flag_values = {
"//jax:build_jaxlib": "false",
"@local_config_cuda//:enable_cuda": "True",
},
)

config_setting(
name = "enable_py_import_for_cpu_build",
flag_values = {
"//jax:build_jaxlib": "wheel",
"@local_config_cuda//:enable_cuda": "False",
},
)

config_setting(
name = "enable_py_import_for_cuda12_build",
flag_values = {
"//jax:build_jaxlib": "wheel",
"@local_config_cuda//:enable_cuda": "True",
},
)
1 change: 0 additions & 1 deletion jax_plugins/rocm/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ licenses(["notice"])

load(
"//jaxlib:jax.bzl",
"if_windows",
"py_library_providing_imports_info",
"pytype_library",
)
Expand Down
114 changes: 39 additions & 75 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -167,70 +167,36 @@ def if_building_jaxlib(
if_building: the source code targets to depend on in case we don't depend on the jaxlib wheels
if_not_building: the wheels to depend on if we are not depending directly on //jaxlib.
"""

return select({
"//jax:config_build_jaxlib_true": if_building,
"//jax:config_build_jaxlib_false": if_not_building,
"//jax:config_build_jaxlib_wheel": [],
})

def _get_test_deps(deps, backend_independent):
"""Returns the test deps for the given backend.

Args:
deps: the full list of test dependencies
backend_independent: whether the test is backend independent
def _cpu_test_deps():
"""Returns the test depencies needed for a CPU-only JAX test."""
return select({
"//jax:config_build_jaxlib_true": [],
"//jax:config_build_jaxlib_false": ["@pypi//jaxlib"],
"//jax:config_build_jaxlib_wheel": ["//jaxlib/tools:jaxlib_py_import"],
})

Returns:
A list of test deps for the given backend.
For CPU builds:
If --//jax:build_jaxlib=true, returns pypi test deps.
If --//jax:build_jaxlib=false, returns jaxlib pypi wheel dep and pypi test deps.
If --//jax:build_jaxlib=wheel, returns jaxlib py_import dep and pypi test deps.
For GPU builds:
If --//jax:build_jaxlib=true, returns pypi test deps and gpu build deps.
If --//jax:build_jaxlib=false, returns jaxlib, jax-cuda-plugin,
jax-cuda-pjrt pypi wheel deps and pypi test deps.
If --//jax:build_jaxlib=wheel, returns jaxlib,
jax-cuda-plugin, jax-cuda-pjrt py_import deps and pypi test deps.
"""
gpu_build_deps = [
"//jaxlib/cuda:gpu_only_test_deps",
"//jaxlib/rocm:gpu_only_test_deps",
"//jax_plugins:gpu_plugin_only_test_deps",
]
pypi_test_deps = [d for d in deps if d.startswith("@pypi//")]

gpu_py_imports = [
"//jaxlib/tools:jaxlib_py_import",
"//jaxlib/tools:jax_cuda_plugin_py_import",
"//jaxlib/tools:jax_cuda_pjrt_py_import",
] + pypi_test_deps
cpu_py_imports = [
"//jaxlib/tools:jaxlib_py_import",
] + pypi_test_deps
jaxlib_pypi_wheel_deps = [
"@pypi//jaxlib",
] + pypi_test_deps

if backend_independent:
test_deps = pypi_test_deps
gpu_pypi_wheel_deps = jaxlib_pypi_wheel_deps
gpu_py_import_deps = cpu_py_imports
else:
test_deps = gpu_build_deps + pypi_test_deps
gpu_pypi_wheel_deps = jaxlib_pypi_wheel_deps + [
def _gpu_test_deps():
"""Returns the additional dependencies needed for a GPU test."""
return select({
"//jax:config_build_jaxlib_true": [
"//jaxlib/cuda:gpu_only_test_deps",
"//jaxlib/rocm:gpu_only_test_deps",
"//jax_plugins:gpu_plugin_only_test_deps",
],
"//jax:config_build_jaxlib_false": [
"@pypi//jax_cuda12_plugin",
"@pypi//jax_cuda12_pjrt",
]
gpu_py_import_deps = gpu_py_imports

return select({
"//jax:config_build_jaxlib_true": test_deps,
"//jax_plugins/cuda:disable_jaxlib_for_cpu_build": jaxlib_pypi_wheel_deps,
"//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": gpu_pypi_wheel_deps,
"//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports,
"//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_import_deps,
],
"//jax:config_build_jaxlib_wheel": [
"//jaxlib/tools:jax_cuda_plugin_py_import",
"//jaxlib/tools:jax_cuda_pjrt_py_import",
],
})

def _get_jax_test_deps(deps):
Expand All @@ -246,28 +212,23 @@ def _get_jax_test_deps(deps):
If --//jax:build_jax=false, returns jax pypi wheel dep and transitive pypi test deps.
If --//jax:build_jax=wheel, returns jax py_import dep and transitive pypi test deps.
"""
jax_build_deps = [d for d in deps if not d.startswith("@pypi//")]
non_pypi_deps = [d for d in deps if not d.startswith("@pypi//")]

# A lot of tests don't have explicit dependencies on scipy, ml_dtypes, etc. But the tests
# transitively depends on them via //jax. So we need to make sure that these dependencies are
# included in the test when JAX is built from source.
jax_transitive_pypi_test_deps = {k: "true" for k in py_deps([
pypi_deps = depset([d for d in deps if d.startswith("@pypi//")])
pypi_deps = depset(py_deps([
"ml_dtypes",
"scipy",
"opt_einsum",
"flatbuffers",
])}
]), transitive = [pypi_deps]).to_list()

# Remove the pypi deps that are already provided by _get_test_deps().
for d in deps:
if d.startswith("@pypi//") and jax_transitive_pypi_test_deps.get(d):
jax_transitive_pypi_test_deps.pop(d)
return select({
"//jax:disable_jaxlib_and_jax_build": ["//:jax_wheel_with_internal_test_util"] +
jax_transitive_pypi_test_deps.keys(),
"//jax:enable_jaxlib_and_jax_py_import": ["//:jax_py_import"] +
jax_transitive_pypi_test_deps.keys(),
"//conditions:default": jax_build_deps + jax_transitive_pypi_test_deps.keys(),
return pypi_deps + select({
"//jax:config_build_jax_false": ["//:jax_wheel_with_internal_test_util"],
"//jax:config_build_jax_wheel": ["//:jax_py_import"],
"//jax:config_build_jax_true": non_pypi_deps,
})

# buildifier: disable=function-docstring
Expand Down Expand Up @@ -316,18 +277,21 @@ def jax_multiplatform_test(
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
test_tags.append("manual")
test_deps = _cpu_test_deps() + _get_jax_test_deps([
"//jax",
"//jax:test_util",
] + deps)
if backend == "gpu":
test_deps += _gpu_test_deps()
test_tags += tf_cuda_tests_tags()
elif backend == "tpu":
test_deps += ["@pypi//libtpu"]
native.py_test(
name = name + "_" + backend,
srcs = srcs,
args = test_args,
env = env,
deps = _get_test_deps(deps, backend_independent = False) +
_get_jax_test_deps([
"//jax",
"//jax:test_util",
] + deps),
deps = test_deps,
data = data,
shard_count = test_shards,
tags = test_tags,
Expand Down Expand Up @@ -620,13 +584,13 @@ def jax_py_test(
env = dict(env)
env.setdefault("PYTHONWARNINGS", "error")
deps = kwargs.get("deps", [])
test_deps = _get_test_deps(deps, backend_independent = True) + _get_jax_test_deps(deps)
test_deps = _cpu_test_deps() + _get_jax_test_deps(deps)
kwargs["deps"] = test_deps
py_test(name = name, env = env, **kwargs)

def pytype_test(name, **kwargs):
deps = kwargs.get("deps", [])
test_deps = _get_test_deps(deps, backend_independent = True) + _get_jax_test_deps(deps)
test_deps = _cpu_test_deps() + _get_jax_test_deps(deps)
kwargs["deps"] = test_deps
native.py_test(name = name, **kwargs)

Expand Down
Loading