Skip to content

Commit 7d4b7bb

Browse files
authored
Fix flaky JaxLayer test. (#20756)
The `DTypePolicy` test produces lower precision results.
1 parent e010829 commit 7d4b7bb

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

keras/src/utils/jax_layer_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,11 +324,13 @@ def verify_identical_model(model):
324324
model2.export(path, format="tf_saved_model")
325325
model4 = tf.saved_model.load(path)
326326
output4 = model4.serve(x_test)
327+
# The output difference is greater when using the GPU or bfloat16
328+
lower_precision = testing.jax_uses_gpu() or "dtype" in layer_init_kwargs
327329
self.assertAllClose(
328330
output1,
329331
output4,
330-
# The output difference might be significant when using the GPU
331-
atol=1e-2 if testing.jax_uses_gpu() else 1e-6,
332+
atol=1e-2 if lower_precision else 1e-6,
333+
rtol=1e-3 if lower_precision else 1e-6,
332334
)
333335

334336
# test subclass model building without a build method

0 commit comments

Comments
 (0)