We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
JaxLayer
1 parent e010829 commit 7d4b7bbCopy full SHA for 7d4b7bb
keras/src/utils/jax_layer_test.py
@@ -324,11 +324,13 @@ def verify_identical_model(model):
324
model2.export(path, format="tf_saved_model")
325
model4 = tf.saved_model.load(path)
326
output4 = model4.serve(x_test)
327
+ # The output difference is greater when using the GPU or bfloat16
328
+ lower_precision = testing.jax_uses_gpu() or "dtype" in layer_init_kwargs
329
self.assertAllClose(
330
output1,
331
output4,
- # The output difference might be significant when using the GPU
- atol=1e-2 if testing.jax_uses_gpu() else 1e-6,
332
+ atol=1e-2 if lower_precision else 1e-6,
333
+ rtol=1e-3 if lower_precision else 1e-6,
334
)
335
336
# test subclass model building without a build method
0 commit comments