@@ -654,7 +654,7 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
654654        Args: 
655655            constrained: Over these values the constraint is specified. A rank-1 
656656                tensor. 
657-             dependent: From these values the maximum that satiesfies  the 
657+             dependent: From these values the maximum that satisfies  the 
658658                constraint is selected. Values in this tensor and in 
659659                `constrained` are linked by having the same threshold at each 
660660                position, hence this tensor must have the same shape. 
@@ -664,11 +664,12 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
664664        Returns: 
665665            maximal dependent value, if no value satisfies the constraint 0.0. 
666666        """ 
667-         feasible  =  ops .nonzero (predicate (constrained , self .value ))
668-         feasible_exists  =  ops .greater (ops .size (feasible ), 0 )
669-         max_dependent  =  ops .max (ops .take (dependent , feasible ), initial = 0 )
670- 
671-         return  ops .where (feasible_exists , max_dependent , 0.0 )
667+         feasible  =  predicate (constrained , self .value )
668+         # Mask values based on whether they satisfy the constraint and take max. 
669+         return  ops .max (
670+             ops .multiply (dependent , ops .cast (feasible , dependent .dtype )),
671+             initial = 0 ,
672+         )
672673
673674
674675@keras_export ("keras.metrics.SensitivityAtSpecificity" ) 
0 commit comments