Skip to content

Commit 3c3d6ad

Browse files
Remove the unused jax enable_x64. (keras-team#21737)
* Remove the unused jax x64 context. * Fix ops.trace dtype.
1 parent 232cf26 commit 3c3d6ad

File tree

7 files changed

+165
-439
lines changed

7 files changed

+165
-439
lines changed

keras/src/backend/common/variables_test.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -971,36 +971,17 @@ def test_mul(self, dtypes):
971971
def test_truediv(self, dtypes):
972972
import jax.numpy as jnp
973973

974-
try:
975-
# JAX v0.8.0 and newer
976-
from jax import enable_x64
977-
except ImportError:
978-
# JAX v0.7.2 and older
979-
from jax.experimental import enable_x64
980-
981-
# We have to disable x64 for jax since jnp.true_divide doesn't respect
982-
# JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast
983-
# the expected dtype from 64 bit to 32 bit when using jax backend.
984-
with enable_x64(False):
985-
dtype1, dtype2 = dtypes
986-
x1 = backend.Variable(
987-
"ones", shape=(1,), dtype=dtype1, trainable=False
988-
)
989-
x2 = backend.Variable(
990-
"ones", shape=(1,), dtype=dtype2, trainable=False
991-
)
992-
x1_jax = jnp.ones((1,), dtype=dtype1)
993-
x2_jax = jnp.ones((1,), dtype=dtype2)
994-
expected_dtype = standardize_dtype(
995-
jnp.true_divide(x1_jax, x2_jax).dtype
996-
)
997-
if "float64" in (dtype1, dtype2):
998-
expected_dtype = "float64"
999-
if backend.backend() == "jax":
1000-
expected_dtype = expected_dtype.replace("64", "32")
974+
dtype1, dtype2 = dtypes
975+
x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False)
976+
x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False)
977+
x1_jax = jnp.ones((1,), dtype=dtype1)
978+
x2_jax = jnp.ones((1,), dtype=dtype2)
979+
expected_dtype = standardize_dtype(
980+
jnp.true_divide(x1_jax, x2_jax).dtype
981+
)
1001982

1002-
self.assertDType(x1 / x2, expected_dtype)
1003-
self.assertDType(x1.__rtruediv__(x2), expected_dtype)
983+
self.assertDType(x1 / x2, expected_dtype)
984+
self.assertDType(x1.__rtruediv__(x2), expected_dtype)
1004985

1005986
@parameterized.named_parameters(
1006987
named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))

keras/src/backend/jax/numpy.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,14 +1240,7 @@ def tile(x, repeats):
12401240

12411241
def trace(x, offset=0, axis1=0, axis2=1):
12421242
x = convert_to_tensor(x)
1243-
dtype = None
1244-
# TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27
1245-
# for both CPU & GPU environments.
1246-
# uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32
1247-
# otherwise.
1248-
if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"):
1249-
dtype = "int32"
1250-
return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
1243+
return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
12511244

12521245

12531246
def tri(N, M=None, k=0, dtype=None):

keras/src/backend/numpy/numpy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,8 +1177,10 @@ def trace(x, offset=0, axis1=0, axis2=1):
11771177
axis2 = standardize_axis_for_numpy(axis2)
11781178
x = convert_to_tensor(x)
11791179
dtype = standardize_dtype(x.dtype)
1180-
if dtype not in ("int64", "uint32", "uint64"):
1181-
dtype = dtypes.result_type(dtype, "int32")
1180+
if dtype in ("bool", "int8", "int16"):
1181+
dtype = "int32"
1182+
elif dtype in ("uint8", "uint16"):
1183+
dtype = "uint32"
11821184
return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
11831185

11841186

keras/src/backend/tensorflow/numpy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2701,8 +2701,11 @@ def tile(x, repeats):
27012701
def trace(x, offset=0, axis1=0, axis2=1):
27022702
x = convert_to_tensor(x)
27032703
dtype = standardize_dtype(x.dtype)
2704-
if dtype not in ("int64", "uint32", "uint64"):
2705-
dtype = dtypes.result_type(dtype, "int32")
2704+
if dtype in ("bool", "int8", "int16"):
2705+
dtype = "int32"
2706+
elif dtype in ("uint8", "uint16"):
2707+
dtype = "uint32"
2708+
x = tf.cast(x, dtype)
27062709
x_shape = tf.shape(x)
27072710
x = moveaxis(x, (axis1, axis2), (-2, -1))
27082711
# Mask out the diagonal and reduce.
@@ -2711,10 +2714,7 @@ def trace(x, offset=0, axis1=0, axis2=1):
27112714
x,
27122715
tf.zeros_like(x),
27132716
)
2714-
# The output dtype is set to "int32" if the input dtype is "bool"
2715-
if standardize_dtype(x.dtype) == "bool":
2716-
x = tf.cast(x, "int32")
2717-
return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype)
2717+
return tf.reduce_sum(x, axis=(-2, -1))
27182718

27192719

27202720
def tri(N, M=None, k=0, dtype=None):

keras/src/backend/torch/numpy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,8 +1646,9 @@ def tile(x, repeats):
16461646
def trace(x, offset=0, axis1=0, axis2=1):
16471647
x = convert_to_tensor(x)
16481648
dtype = standardize_dtype(x.dtype)
1649-
if dtype != "int64":
1650-
dtype = dtypes.result_type(dtype, "int32")
1649+
if dtype in ("bool", "int8", "int16", "uint8"):
1650+
# Torch backend doesn't support uint32 dtype.
1651+
dtype = "int32"
16511652
return torch.sum(
16521653
torch.diagonal(x, offset, axis1, axis2),
16531654
dim=-1,

keras/src/ops/numpy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6365,8 +6365,13 @@ def compute_output_spec(self, x):
63656365
x_shape[self.axis2] = -1
63666366
output_shape = list(filter((-1).__ne__, x_shape))
63676367
output_dtype = backend.standardize_dtype(x.dtype)
6368-
if output_dtype not in ("int64", "uint32", "uint64"):
6369-
output_dtype = dtypes.result_type(output_dtype, "int32")
6368+
if output_dtype in ("bool", "int8", "int16"):
6369+
output_dtype = "int32"
6370+
elif output_dtype in ("uint8", "uint16"):
6371+
output_dtype = "uint32"
6372+
if output_dtype == "uint32" and backend.backend() == "torch":
6373+
# Torch backend doesn't support uint32 dtype.
6374+
output_dtype = "int32"
63706375
return KerasTensor(output_shape, dtype=output_dtype)
63716376

63726377

0 commit comments

Comments
 (0)