81 changes: 73 additions & 8 deletions keras/src/metrics/accuracy_metrics.py
@@ -380,10 +380,32 @@ def get_config(self):


@keras_export("keras.metrics.sparse_top_k_categorical_accuracy")
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
def sparse_top_k_categorical_accuracy(
y_true, y_pred, k=5, from_sorted_ids=False
):
"""Computes how often integer targets are in the top `K` predictions.

Args:
y_true: A tensor of shape `(batch_size)` representing indices or IDs of
true categories.
y_pred: If `from_sorted_ids=False`, a tensor of shape
`(batch_size, num_categories)` containing the scores for each sample
for all possible categories. If `from_sorted_ids=True`, a tensor of
shape `(batch_size, N)` containing indices or IDs of the top `N`
categories in order from highest score to lowest score.
k: (Optional) Number of top elements to look at for computing accuracy.
Defaults to `5`.
from_sorted_ids: (Optional) Whether `y_pred` contains sorted category
IDs (`True`) or scores for all categories (`False`, the default).

Returns:
A tensor with the same shape as `y_true` containing ones where `y_true`
is in the top `k` and zeros elsewhere.
"""
reshape_matches = False
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
y_true_dtype = y_pred.dtype if from_sorted_ids else "int32"
y_true = ops.convert_to_tensor(y_true, dtype=y_true_dtype)
y_true_rank = len(y_true.shape)
y_pred_rank = len(y_pred.shape)
y_true_org_shape = ops.shape(y_true)
@@ -396,10 +418,16 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
reshape_matches = True
y_true = ops.reshape(y_true, [-1])

matches = ops.cast(
ops.in_top_k(ops.cast(y_true, "int32"), y_pred, k=k),
dtype=backend.floatx(),
)
if from_sorted_ids:
# Slicing the first k items assumes they are sorted by descending score.
# Reduce with `any` to count multiple matches only once.
matches = ops.any(
ops.equal(ops.expand_dims(y_true, axis=1), y_pred[:, :k]), axis=1
)
else:
matches = ops.in_top_k(y_true, y_pred, k=k)

matches = ops.cast(matches, dtype=backend.floatx())

# The returned matches must have the same shape as the `y_true` input.
if reshape_matches:
@@ -412,11 +440,33 @@ class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
"""Computes how often integer targets are in the top `K` predictions.

By default, the arguments expected by `update_state()` are:
- `y_true`: a tensor of shape `(batch_size)` representing indices of true
categories.
- `y_pred`: a tensor of shape `(batch_size, num_categories)` containing the
scores for each sample for all possible categories.

With `from_sorted_ids=True`, the arguments expected by `update_state` are:
- `y_true`: a tensor of shape `(batch_size)` representing indices or IDs of
true categories.
- `y_pred`: a tensor of shape `(batch_size, N)` containing the indices or
IDs of the top `N` categories sorted in order from highest score to
lowest score. `N` must be greater than or equal to `k`.

The `from_sorted_ids=True` option can be more efficient when the set of
categories is very large and the model has an optimized way to retrieve the
top ones, either without scoring or without maintaining the scores for all
the possible categories.

Args:
k: (Optional) Number of top elements to look at for computing accuracy.
Defaults to `5`.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
from_sorted_ids: (Optional) When `False`, the default, the tensor passed
in `y_pred` contains the unsorted scores of all possible categories.
When `True`, `y_pred` contains the indices or IDs of the top
categories.

Example:

@@ -431,6 +481,12 @@ class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
>>> m.result()
0.3

>>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1,
... from_sorted_ids=True)
>>> m.update_state([2, 1], [[1, 0, 3], [1, 2, 3]])
>>> m.result()
0.5

Usage with `compile()` API:

```python
@@ -441,17 +497,26 @@ class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
"""

def __init__(
self, k=5, name="sparse_top_k_categorical_accuracy", dtype=None
self,
k=5,
name="sparse_top_k_categorical_accuracy",
dtype=None,
from_sorted_ids=False,
):
super().__init__(
fn=sparse_top_k_categorical_accuracy,
name=name,
dtype=dtype,
k=k,
from_sorted_ids=from_sorted_ids,
)
self.k = k
self.from_sorted_ids = from_sorted_ids
# Metric should be maximized during optimization.
self._direction = "up"

def get_config(self):
return {"name": self.name, "dtype": self.dtype, "k": self.k}
config = {"name": self.name, "dtype": self.dtype, "k": self.k}
if self.from_sorted_ids:
config["from_sorted_ids"] = True
return config
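
A minimal standalone sketch (not part of the diff) of what the new
`from_sorted_ids=True` branch computes, using the same example values as the
docstring and assuming the Keras 3 `keras.ops` API:

```python
import numpy as np
from keras import ops

# Example values from the docstring: true IDs and top-3 predicted IDs.
y_true = ops.convert_to_tensor(np.array([2, 1]), dtype="int32")
y_pred = ops.convert_to_tensor(np.array([[1, 0, 3], [1, 2, 3]]), dtype="int32")
k = 1

# Compare the true ID against the first k columns (assumed sorted by
# descending score); `any` collapses duplicate matches to a single hit.
matches = ops.any(
    ops.equal(ops.expand_dims(y_true, axis=1), y_pred[:, :k]), axis=1
)
print(ops.convert_to_numpy(matches))  # [False  True] -> accuracy 0.5
```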
50 changes: 50 additions & 0 deletions keras/src/metrics/accuracy_metrics_test.py
@@ -440,6 +440,27 @@ def test_config(self):
self.assertEqual(len(sp_top_k_cat_acc_obj2.variables), 2)
self.assertEqual(sp_top_k_cat_acc_obj2._dtype, "float32")
self.assertEqual(sp_top_k_cat_acc_obj2.k, 1)
self.assertFalse(sp_top_k_cat_acc_obj2.from_sorted_ids)

def test_config_from_sorted_ids(self):
sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(
k=1,
name="sparse_top_k_categorical_accuracy",
dtype="float32",
from_sorted_ids=True,
)

# Test get_config
sp_top_k_cat_acc_obj_config = sp_top_k_cat_acc_obj.get_config()
self.assertTrue(sp_top_k_cat_acc_obj_config["from_sorted_ids"])

# Check save and restore config
sp_top_k_cat_acc_obj2 = (
accuracy_metrics.SparseTopKCategoricalAccuracy.from_config(
sp_top_k_cat_acc_obj_config
)
)
self.assertTrue(sp_top_k_cat_acc_obj2.from_sorted_ids)

def test_unweighted(self):
sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(
@@ -463,3 +484,32 @@ def test_weighted(self):
)
result = sp_top_k_cat_acc_obj.result()
self.assertAllClose(result, 0.3, atol=1e-3)

def test_from_sorted_ids_unweighted(self):
sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(
k=1,
name="sparse_top_k_categorical_accuracy",
dtype="float32",
from_sorted_ids=True,
)
y_true = np.array([2, 1])
y_pred = np.array([[1, 0, 3], [1, 2, 3]])
sp_top_k_cat_acc_obj.update_state(y_true, y_pred)
result = sp_top_k_cat_acc_obj.result()
self.assertAllClose(result, 0.5, atol=1e-3)

def test_from_sorted_ids_weighted(self):
sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(
k=1,
name="sparse_top_k_categorical_accuracy",
dtype="float32",
from_sorted_ids=True,
)
y_true = np.array([2, 1])
y_pred = np.array([[1, 0, 3], [1, 2, 3]])
sample_weight = np.array([0.7, 0.3])
sp_top_k_cat_acc_obj.update_state(
y_true, y_pred, sample_weight=sample_weight
)
result = sp_top_k_cat_acc_obj.result()
self.assertAllClose(result, 0.3, atol=1e-3)
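
The asserted value in `test_from_sorted_ids_weighted` can be checked by hand:
with `k=1` the top-1 IDs are `[1, 1]`, so the matches against
`y_true = [2, 1]` are `[0, 1]`. A sketch of the reduction, assuming
`MeanMetricWrapper` computes `sum(matches * weights) / sum(weights)`:

```python
import numpy as np

matches = np.array([0.0, 1.0])        # top-1 ID matches against y_true = [2, 1]
sample_weight = np.array([0.7, 0.3])

# Weighted mean over samples.
result = np.sum(matches * sample_weight) / np.sum(sample_weight)
print(result)  # 0.3, matching the test's expected value
```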