Skip to content

Commit dc5e42c

Browse files
fix sas metrics in jax fit (#21765)
1 parent 1ba3b8f commit dc5e42c

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import jax.experimental.sparse as jax_sparse
55
import jax.numpy as jnp
6+
from jax import core
67
from jax import export as jax_export
78

89
from keras.src.backend import config
@@ -1001,6 +1002,11 @@ def ndim(x):
10011002

10021003

10031004
def nonzero(x):
1005+
if isinstance(x, core.Tracer):
1006+
# needed because this is called for several metric calculations,
1007+
# which will supply tracer values during `fit` execution
1008+
return jnp.nonzero(x, size=core.get_aval(x).size)[0]
1009+
10041010
return jnp.nonzero(x)
10051011

10061012

keras/src/metrics/confusion_metrics_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,20 @@ def test_invalid_num_thresholds(self):
787787
):
788788
metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)
789789

790+
@pytest.mark.requires_trainable_backend
791+
def test_handles_sas_metrics(self):
792+
# Test for https://github.com/keras-team/keras/issues/19376
793+
model = models.Sequential(
794+
[
795+
layers.Input((1,)),
796+
layers.Dense(1),
797+
]
798+
)
799+
sas = metrics.SpecificityAtSensitivity(0.5, name="sas")
800+
801+
model.compile(optimizer="adam", loss="crossentropy", metrics=[sas])
802+
model.fit(np.ones((5, 1)), np.ones((5, 1)))
803+
790804

791805
class SpecificityAtSensitivityTest(testing.TestCase):
792806
def test_config(self):

0 commit comments

Comments
 (0)