@@ -380,10 +380,32 @@ def get_config(self):
380380
381381
382382@keras_export ("keras.metrics.sparse_top_k_categorical_accuracy" )
383- def sparse_top_k_categorical_accuracy (y_true , y_pred , k = 5 ):
383+ def sparse_top_k_categorical_accuracy (
384+ y_true , y_pred , k = 5 , from_sorted_ids = False
385+ ):
386+ """Computes how often integer targets are in the top `K` predictions.
387+
388+ Args:
389+ y_true: A tensor of shape `(batch_size)` representing indices or IDs of
390+ true categories.
391+ y_pred: If `from_sorted_ids=False`, a tensor of shape
392+ `(batch_size, num_categories)` containing the scores for each sample
393+ for all possible categories. If `from_sorted_ids=True`, a tensor of
394+ shape `(batch_size, N)` containing indices or IDs of the top `N`
395+ categories in order from highest score to lowest score.
396+ k: (Optional) Number of top elements to look at for computing accuracy.
397+ Defaults to `5`.
398+ from_sorted_ids: (Optional) Whether `y_pred` is sorted category IDs or
399+ scores for all categories (the default).
400+
401+ Returns:
402+ A tensor with the same shape as `y_true` containing ones where `y_true`
403+ is in the top `k` and zeros elsewhere.
404+ """
384405 reshape_matches = False
385406 y_pred = ops .convert_to_tensor (y_pred )
386- y_true = ops .convert_to_tensor (y_true , dtype = y_pred .dtype )
407+ y_true_dtype = y_pred .dtype if from_sorted_ids else "int32"
408+ y_true = ops .convert_to_tensor (y_true , dtype = y_true_dtype )
387409 y_true_rank = len (y_true .shape )
388410 y_pred_rank = len (y_pred .shape )
389411 y_true_org_shape = ops .shape (y_true )
@@ -396,10 +418,16 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
396418 reshape_matches = True
397419 y_true = ops .reshape (y_true , [- 1 ])
398420
399- matches = ops .cast (
400- ops .in_top_k (ops .cast (y_true , "int32" ), y_pred , k = k ),
401- dtype = backend .floatx (),
402- )
421+ if from_sorted_ids :
422+ # By slicing the first k items, we assume they are sorted by score.
423+ # Reduce with `any` to count multiple matches only once.
424+ matches = ops .any (
425+ ops .equal (ops .expand_dims (y_true , axis = 1 ), y_pred [:, :k ]), axis = 1
426+ )
427+ else :
428+ matches = ops .in_top_k (y_true , y_pred , k = k )
429+
430+ matches = ops .cast (matches , dtype = backend .floatx ())
403431
404432 # returned matches is expected to have same shape as y_true input
405433 if reshape_matches :
@@ -412,11 +440,33 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
412440class SparseTopKCategoricalAccuracy (reduction_metrics .MeanMetricWrapper ):
413441 """Computes how often integer targets are in the top `K` predictions.
414442
443+ By default, the arguments expected by `update_state()` are:
444+ - `y_true`: a tensor of shape `(batch_size)` representing indices of true
445+ categories.
446+ - `y_pred`: a tensor of shape `(batch_size, num_categories)` containing the
447+ scores for each sample for all possible categories.
448+
449+ With `from_sorted_ids=True`, the arguments expected by `update_state` are:
450+ - `y_true`: a tensor of shape `(batch_size)` representing indices or IDs of
451+ true categories.
452+ - `y_pred`: a tensor of shape `(batch_size, N)` containing the indices or
453+ IDs of the top `N` categories sorted in order from highest score to
454+ lowest score. `N` must be greater than or equal to `k`.
455+
456+ The `from_sorted_ids=True` option can be more efficient when the set of
457+ categories is very large and the model has an optimized way to retrieve the
458+ top ones either without scoring or without maintaining the scores for all
459+ the possible categories.
460+
415461 Args:
416462 k: (Optional) Number of top elements to look at for computing accuracy.
417463 Defaults to `5`.
418464 name: (Optional) string name of the metric instance.
419465 dtype: (Optional) data type of the metric result.
466+ from_sorted_ids: (Optional) When `False`, the default, the tensor passed
467+ in `y_pred` contains the unsorted scores of all possible categories.
468+ When `True`, `y_pred` contains the indices or IDs for the top
469+ categories.
420470
421471 Example:
422472
@@ -431,6 +481,12 @@ class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
431481 >>> m.result()
432482 0.3
433483
484+ >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1,
485+ ... from_sorted_ids=True)
486+ >>> m.update_state([2, 1], [[1, 0, 3], [1, 2, 3]])
487+ >>> m.result()
488+ 0.5
489+
434490 Usage with `compile()` API:
435491
436492 ```python
@@ -441,17 +497,26 @@ class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
441497 """
442498
443499 def __init__ (
444- self , k = 5 , name = "sparse_top_k_categorical_accuracy" , dtype = None
500+ self ,
501+ k = 5 ,
502+ name = "sparse_top_k_categorical_accuracy" ,
503+ dtype = None ,
504+ from_sorted_ids = False ,
445505 ):
446506 super ().__init__ (
447507 fn = sparse_top_k_categorical_accuracy ,
448508 name = name ,
449509 dtype = dtype ,
450510 k = k ,
511+ from_sorted_ids = from_sorted_ids ,
451512 )
452513 self .k = k
514+ self .from_sorted_ids = from_sorted_ids
453515 # Metric should be maximized during optimization.
454516 self ._direction = "up"
455517
456518 def get_config (self ):
457- return {"name" : self .name , "dtype" : self .dtype , "k" : self .k }
519+ config = {"name" : self .name , "dtype" : self .dtype , "k" : self .k }
520+ if self .from_sorted_ids :
521+ config ["from_sorted_ids" ] = True
522+ return config
0 commit comments