Skip to content

Commit 965c4a2

Browse files
authored
Add from_sorted_ids option to SparseTopKCategoricalAccuracy. (#20433)
to consume sorted IDs of top N categories instead of scores for all categories.
1 parent 2be521f commit 965c4a2

File tree

2 files changed

+123
-8
lines changed

2 files changed

+123
-8
lines changed

keras/src/metrics/accuracy_metrics.py

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -380,10 +380,32 @@ def get_config(self):
380380

381381

382382
@keras_export("keras.metrics.sparse_top_k_categorical_accuracy")
383-
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
383+
def sparse_top_k_categorical_accuracy(
384+
y_true, y_pred, k=5, from_sorted_ids=False
385+
):
386+
"""Computes how often integer targets are in the top `K` predictions.
387+
388+
Args:
389+
y_true: A tensor of shape `(batch_size)` representing indices or IDs of
390+
true categories.
391+
y_pred: If `from_sorted_ids=False`, a tensor of shape
392+
`(batch_size, num_categories)` containing the scores for each sample
393+
for all possible categories. If `from_sorted_ids=True`, a tensor of
394+
shape `(batch_size, N)` containing indices or IDs of the top `N`
395+
categories in order from highest score to lowest score.
396+
k: (Optional) Number of top elements to look at for computing accuracy.
397+
Defaults to `5`.
398+
from_sorted_ids: (Optional) Whether `y_pred` is sorted category IDs or
399+
scores for all categories (the default).
400+
401+
Returns:
402+
A tensor with the same shape as `y_true` containing ones where `y_true`
403+
is in the top `k` and zeros elsewhere.
404+
"""
384405
reshape_matches = False
385406
y_pred = ops.convert_to_tensor(y_pred)
386-
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
407+
y_true_dtype = y_pred.dtype if from_sorted_ids else "int32"
408+
y_true = ops.convert_to_tensor(y_true, dtype=y_true_dtype)
387409
y_true_rank = len(y_true.shape)
388410
y_pred_rank = len(y_pred.shape)
389411
y_true_org_shape = ops.shape(y_true)
@@ -396,10 +418,16 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
396418
reshape_matches = True
397419
y_true = ops.reshape(y_true, [-1])
398420

399-
matches = ops.cast(
400-
ops.in_top_k(ops.cast(y_true, "int32"), y_pred, k=k),
401-
dtype=backend.floatx(),
402-
)
421+
if from_sorted_ids:
422+
# By slicing the first k items, we assume they are sorted by score.
423+
# Reduce with `any` to count multiple matches only once.
424+
matches = ops.any(
425+
ops.equal(ops.expand_dims(y_true, axis=1), y_pred[:, :k]), axis=1
426+
)
427+
else:
428+
matches = ops.in_top_k(y_true, y_pred, k=k)
429+
430+
matches = ops.cast(matches, dtype=backend.floatx())
403431

404432
# returned matches is expected to have same shape as y_true input
405433
if reshape_matches:
@@ -412,11 +440,33 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
412440
class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
413441
"""Computes how often integer targets are in the top `K` predictions.
414442
443+
By default, the arguments expected by `update_state()` are:
444+
- `y_true`: a tensor of shape `(batch_size)` representing indices of true
445+
categories.
446+
- `y_pred`: a tensor of shape `(batch_size, num_categories)` containing the
447+
scores for each sample for all possible categories.
448+
449+
With `from_sorted_ids=True`, the arguments expected by `update_state` are:
450+
- `y_true`: a tensor of shape `(batch_size)` representing indices or IDs of
451+
true categories.
452+
- `y_pred`: a tensor of shape `(batch_size, N)` containing the indices or
453+
IDs of the top `N` categories sorted in order from highest score to
454+
lowest score. `N` must be greater or equal to `k`.
455+
456+
The `from_sorted_ids=True` option can be more efficient when the set of
457+
categories is very large and the model has an optimized way to retrieve the
458+
top ones either without scoring or without maintaining the scores for all
459+
the possible categories.
460+
415461
Args:
416462
k: (Optional) Number of top elements to look at for computing accuracy.
417463
Defaults to `5`.
418464
name: (Optional) string name of the metric instance.
419465
dtype: (Optional) data type of the metric result.
466+
from_sorted_ids: (Optional) When `False`, the default, the tensor passed
467+
in `y_pred` contains the unsorted scores of all possible categories.
468+
When `True`, `y_pred` contains a the indices or IDs for the top
469+
categories.
420470
421471
Example:
422472
@@ -431,6 +481,12 @@ class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
431481
>>> m.result()
432482
0.3
433483
484+
>>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1,
485+
... from_sorted_ids=True)
486+
>>> m.update_state([2, 1], [[1, 0, 3], [1, 2, 3]])
487+
>>> m.result()
488+
0.5
489+
434490
Usage with `compile()` API:
435491
436492
```python
@@ -441,17 +497,26 @@ class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
441497
"""
442498

443499
def __init__(
444-
self, k=5, name="sparse_top_k_categorical_accuracy", dtype=None
500+
self,
501+
k=5,
502+
name="sparse_top_k_categorical_accuracy",
503+
dtype=None,
504+
from_sorted_ids=False,
445505
):
446506
super().__init__(
447507
fn=sparse_top_k_categorical_accuracy,
448508
name=name,
449509
dtype=dtype,
450510
k=k,
511+
from_sorted_ids=from_sorted_ids,
451512
)
452513
self.k = k
514+
self.from_sorted_ids = from_sorted_ids
453515
# Metric should be maximized during optimization.
454516
self._direction = "up"
455517

456518
def get_config(self):
457-
return {"name": self.name, "dtype": self.dtype, "k": self.k}
519+
config = {"name": self.name, "dtype": self.dtype, "k": self.k}
520+
if self.from_sorted_ids:
521+
config["from_sorted_ids"] = True
522+
return config

keras/src/metrics/accuracy_metrics_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,27 @@ def test_config(self):
440440
self.assertEqual(len(sp_top_k_cat_acc_obj2.variables), 2)
441441
self.assertEqual(sp_top_k_cat_acc_obj2._dtype, "float32")
442442
self.assertEqual(sp_top_k_cat_acc_obj2.k, 1)
443+
self.assertFalse(sp_top_k_cat_acc_obj2.from_sorted_ids)
444+
445+
def test_config_from_sorted_ids(self):
446+
sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(
447+
k=1,
448+
name="sparse_top_k_categorical_accuracy",
449+
dtype="float32",
450+
from_sorted_ids=True,
451+
)
452+
453+
# Test get_config
454+
sp_top_k_cat_acc_obj_config = sp_top_k_cat_acc_obj.get_config()
455+
self.assertTrue(sp_top_k_cat_acc_obj_config["from_sorted_ids"])
456+
457+
# Check save and restore config
458+
sp_top_k_cat_acc_obj2 = (
459+
accuracy_metrics.SparseTopKCategoricalAccuracy.from_config(
460+
sp_top_k_cat_acc_obj_config
461+
)
462+
)
463+
self.assertTrue(sp_top_k_cat_acc_obj2.from_sorted_ids)
443464

444465
def test_unweighted(self):
445466
sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(
@@ -463,3 +484,32 @@ def test_weighted(self):
463484
)
464485
result = sp_top_k_cat_acc_obj.result()
465486
self.assertAllClose(result, 0.3, atol=1e-3)
487+
488+
def test_from_sorted_ids_unweighted(self):
489+
sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(
490+
k=1,
491+
name="sparse_top_k_categorical_accuracy",
492+
dtype="float32",
493+
from_sorted_ids=True,
494+
)
495+
y_true = np.array([2, 1])
496+
y_pred = np.array([[1, 0, 3], [1, 2, 3]])
497+
sp_top_k_cat_acc_obj.update_state(y_true, y_pred)
498+
result = sp_top_k_cat_acc_obj.result()
499+
self.assertAllClose(result, 0.5, atol=1e-3)
500+
501+
def test_from_sorted_ids_weighted(self):
502+
sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(
503+
k=1,
504+
name="sparse_top_k_categorical_accuracy",
505+
dtype="float32",
506+
from_sorted_ids=True,
507+
)
508+
y_true = np.array([2, 1])
509+
y_pred = np.array([[1, 0, 3], [1, 2, 3]])
510+
sample_weight = np.array([0.7, 0.3])
511+
sp_top_k_cat_acc_obj.update_state(
512+
y_true, y_pred, sample_weight=sample_weight
513+
)
514+
result = sp_top_k_cat_acc_obj.result()
515+
self.assertAllClose(result, 0.3, atol=1e-3)

0 commit comments

Comments
 (0)