Skip to content

Commit 10b51ce

Browse files
authored
Make confusion metrics compilable. (#21775)
By removing the use of `ops.nonzero` which returns an array of non-predetermined size. Follow-up to https://github.com/keras-team/keras/pull/21765/files#r2462099871 Fixes #19376
1 parent 18f79d6 commit 10b51ce

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import jax.experimental.sparse as jax_sparse
55
import jax.numpy as jnp
6-
from jax import core
76
from jax import export as jax_export
87

98
from keras.src.backend import config
@@ -1002,11 +1001,6 @@ def ndim(x):
10021001

10031002

10041003
def nonzero(x):
1005-
if isinstance(x, core.Tracer):
1006-
# needed because this is called for several metric calculations,
1007-
# which will supply tracer values during `fit` execution
1008-
return jnp.nonzero(x, size=core.get_aval(x).size)[0]
1009-
10101004
return jnp.nonzero(x)
10111005

10121006

keras/src/metrics/confusion_metrics.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
654654
Args:
655655
constrained: Over these values the constraint is specified. A rank-1
656656
tensor.
657-
dependent: From these values the maximum that satiesfies the
657+
dependent: From these values the maximum that satisfies the
658658
constraint is selected. Values in this tensor and in
659659
`constrained` are linked by having the same threshold at each
660660
position, hence this tensor must have the same shape.
@@ -664,11 +664,12 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
664664
Returns:
665665
maximal dependent value, if no value satisfies the constraint 0.0.
666666
"""
667-
feasible = ops.nonzero(predicate(constrained, self.value))
668-
feasible_exists = ops.greater(ops.size(feasible), 0)
669-
max_dependent = ops.max(ops.take(dependent, feasible), initial=0)
670-
671-
return ops.where(feasible_exists, max_dependent, 0.0)
667+
feasible = predicate(constrained, self.value)
668+
# Mask values based on whether they satisfy the constraint and take max.
669+
return ops.max(
670+
ops.multiply(dependent, ops.cast(feasible, dependent.dtype)),
671+
initial=0,
672+
)
672673

673674

674675
@keras_export("keras.metrics.SensitivityAtSpecificity")

0 commit comments

Comments
 (0)