Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import jax.experimental.sparse as jax_sparse
import jax.numpy as jnp
from jax import core
from jax import export as jax_export

from keras.src.backend import config
Expand Down Expand Up @@ -1002,11 +1001,6 @@ def ndim(x):


def nonzero(x):
if isinstance(x, core.Tracer):
# needed because this is called for several metric calculations,
# which will supply tracer values during `fit` execution
return jnp.nonzero(x, size=core.get_aval(x).size)[0]

return jnp.nonzero(x)


Expand Down
13 changes: 7 additions & 6 deletions keras/src/metrics/confusion_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
Args:
constrained: Over these values the constraint is specified. A rank-1
tensor.
dependent: From these values the maximum that satiesfies the
dependent: From these values the maximum that satisfies the
constraint is selected. Values in this tensor and in
`constrained` are linked by having the same threshold at each
position, hence this tensor must have the same shape.
Expand All @@ -664,11 +664,12 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
Returns:
maximal dependent value, if no value satisfies the constraint 0.0.
"""
feasible = ops.nonzero(predicate(constrained, self.value))
feasible_exists = ops.greater(ops.size(feasible), 0)
max_dependent = ops.max(ops.take(dependent, feasible), initial=0)

return ops.where(feasible_exists, max_dependent, 0.0)
feasible = predicate(constrained, self.value)
# Mask values based on whether they satisfy the constraint and take max.
return ops.max(
ops.multiply(dependent, ops.cast(feasible, dependent.dtype)),
initial=0,
)
Comment on lines +669 to +672
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This implementation is correct and makes the function JIT-compatible. For slightly improved readability, you could consider using ops.where to explicitly mask the dependent tensor. This avoids the implicit boolean-to-float conversion from ops.cast and can make the intent clearer.

Suggested change
return ops.max(
ops.multiply(dependent, ops.cast(feasible, dependent.dtype)),
initial=0,
)
return ops.max(
ops.where(feasible, dependent, 0.0),
initial=0,
)



@keras_export("keras.metrics.SensitivityAtSpecificity")
Expand Down