Skip to content

Commit

Permalink
[tf.numpy] Fixes jax_tests/ breakages.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 487364206
  • Loading branch information
wangpengmit authored and copybara-github committed Nov 9, 2022
1 parent 050e0b6 commit 0b3ed2b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 6 additions & 2 deletions trax/tf_numpy/jax_tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def onp_fun(a, b):
check_xla = not set((lhs_dtype, rhs_dtype)).intersection(
(onp.int32, onp.int64))

tol = {onp.float64: 1e-14}
tol = {onp.float64: 1e-14, onp.float16: 0.04, onp.complex128: 6e-15}
tol = max(jtu.tolerance(lhs_dtype, tol), jtu.tolerance(rhs_dtype, tol))
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True,
check_incomplete_shape=True,
Expand Down Expand Up @@ -1301,8 +1301,12 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, onp_op, lnp_op, rng_fact
tol=tol)
# XLA lacks int64 Cumsum/Cumprod kernels (b/168841378).
check_xla = out_dtype != onp.int64
rtol = None
if out_dtype == onp.float16:
rtol = 2e-3
self._CompileAndCheck(
lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True,
lnp_fun, args_maker, check_dtypes=True, rtol=rtol,
check_incomplete_shape=True,
check_experimental_compile=check_xla,
check_xla_forced_compile=check_xla)

Expand Down
4 changes: 3 additions & 1 deletion trax/tf_numpy/jax_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,10 @@ def format_test_name_suffix(opname, shapes, dtypes):

# We use special symbols, represented as singleton objects, to distinguish
# between NumPy scalars, Python scalars, and 0-D arrays.
class ScalarShape(object):
class ScalarShape:
def __len__(self): return 0
def __getitem__(self, i):
raise IndexError(f'index {i} out of range.')
class _NumpyScalar(ScalarShape): pass
class _PythonScalar(ScalarShape): pass
NUMPY_SCALAR_SHAPE = _NumpyScalar()
Expand Down

0 comments on commit 0b3ed2b

Please sign in to comment.