Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable failing test cases when JAX_ENABLE_X64=1 in the Bazel CPU build #25443

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Disable failing test cases when JAX_ENABLE_X64=1 in the Bazel CPU b…
…uild

PiperOrigin-RevId: 705635799
  • Loading branch information
nitins17 authored and Google-ML-Automation committed Dec 12, 2024
commit ecc2673e7b0ca234d5bcfabc8ed65dc9cbbc9b0f
3 changes: 3 additions & 0 deletions tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import jax
from jax import dtypes
from jax import lax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning
from jax.test_util import check_grads
Expand Down Expand Up @@ -205,6 +206,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
))
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
rng = rng_factory(self.rng())
if jtu.test_device_matches(["cpu"]) and (op is lax.cosh or op is lax.cbrt) and config.enable_x64.value:
raise SkipTest("cosh and cbrt grad fails in x64 mode on CPU") # b/383756018
if jtu.test_device_matches(["tpu"]):
if op is lax.pow:
raise SkipTest("pow grad imprecise on tpu")
Expand Down
2 changes: 2 additions & 0 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ def np_fun(x):
))
def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis,
keepdims, initial, inexact, promote_integers):
if jtu.test_device_matches(["cpu"]) and name == "sum" and config.enable_x64.value and dtype == np.float16:
raise unittest.SkipTest("sum op fails in x64 mode on CPU with dtype=float16") # b/383756018
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
Expand Down
Loading