diff --git a/CHANGELOG.md b/CHANGELOG.md
index e2722697b0d..1c6faa98f1b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added warning to `PearsonCorrCoeff` if input has a very small variance for its given dtype ([#1926](https://github.com/Lightning-AI/torchmetrics/pull/1926))
 
+- Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))
+
+
 ### Changed
 
 - Changed all non-task specific classification metrics to be true subtypes of `Metric` ([#1963](https://github.com/Lightning-AI/torchmetrics/pull/1963))
diff --git a/src/torchmetrics/functional/retrieval/reciprocal_rank.py b/src/torchmetrics/functional/retrieval/reciprocal_rank.py
index 7a7af93e079..2a20cd0caec 100644
--- a/src/torchmetrics/functional/retrieval/reciprocal_rank.py
+++ b/src/torchmetrics/functional/retrieval/reciprocal_rank.py
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
+
 import torch
 from torch import Tensor, tensor
 
 from torchmetrics.utilities.checks import _check_retrieval_functional_inputs
 
 
-def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
+def retrieval_reciprocal_rank(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
     """Compute reciprocal rank (for information retrieval). See `Mean Reciprocal Rank`_.
 
     ``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
@@ -27,10 +29,15 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
     Args:
         preds: estimated probabilities of each document to be relevant.
         target: ground truth about each document being relevant or not.
+        top_k: consider only the top k elements (default: ``None``, which considers them all)
 
     Return:
         a single-value tensor with the reciprocal rank (RR) of the predictions ``preds`` wrt the labels ``target``.
 
+    Raises:
+        ValueError:
+            If ``top_k`` is not ``None`` or not an integer greater than 0.
+
     Example:
         >>> from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
         >>> preds = torch.tensor([0.2, 0.3, 0.5])
@@ -41,9 +48,13 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
     """
     preds, target = _check_retrieval_functional_inputs(preds, target)
 
+    top_k = top_k or preds.shape[-1]
+    if not (isinstance(top_k, int) and top_k > 0):
+        raise ValueError(f"Argument ``top_k`` has to be a positive integer or None, but got {top_k}.")
+
+    target = target[preds.topk(min(top_k, preds.shape[-1]), sorted=True, dim=-1)[1]]
     if not target.sum():
         return tensor(0.0, device=preds.device)
 
-    target = target[torch.argsort(preds, dim=-1, descending=True)]
     position = torch.nonzero(target).view(-1)
     return 1.0 / (position[0] + 1.0)
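A quick illustration of the new truncation behaviour (an illustrative aside, not part of the patch): with the docstring's example inputs, the only relevant document sits at rank 2, so the score survives any ``top_k`` of at least 2 and collapses to zero below that.

    >>> import torch
    >>> from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
    >>> preds = torch.tensor([0.2, 0.3, 0.5])
    >>> target = torch.tensor([False, True, False])
    >>> retrieval_reciprocal_rank(preds, target)           # full ranking: first hit at rank 2
    tensor(0.5000)
    >>> retrieval_reciprocal_rank(preds, target, top_k=2)  # hit still inside the cut-off
    tensor(0.5000)
    >>> retrieval_reciprocal_rank(preds, target, top_k=1)  # hit truncated away
    tensor(0.)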
diff --git a/src/torchmetrics/retrieval/average_precision.py b/src/torchmetrics/retrieval/average_precision.py
index 5322dfc8b6a..b9242dfc7b7 100644
--- a/src/torchmetrics/retrieval/average_precision.py
+++ b/src/torchmetrics/retrieval/average_precision.py
@@ -38,8 +38,8 @@ class RetrievalMAP(RetrievalMetric):
 
     As output to ``forward`` and ``compute`` the metric returns the following output:
 
-    - ``rmap`` (:class:`~torch.Tensor`): A tensor with the mean average precision of the predictions ``preds``
-      w.r.t. the labels ``target``
+    - ``map@k`` (:class:`~torch.Tensor`): A single-value tensor with the mean average precision (MAP)
+      of the predictions ``preds`` w.r.t. the labels ``target``.
 
     All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning,
     so that for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -54,9 +54,8 @@ class RetrievalMAP(RetrievalMetric):
         - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
         - ``'error'``: raise a ``ValueError``
 
-        ignore_index:
-            Ignore predictions where the target is equal to this number.
-        top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
+        ignore_index: Ignore predictions where the target is equal to this number.
+        top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
@@ -65,7 +64,7 @@ class RetrievalMAP(RetrievalMetric):
         ValueError:
             If ``ignore_index`` is not `None` or an integer.
         ValueError:
-            If ``top_k`` is not ``None`` or an integer larger than 0.
+            If ``top_k`` is not ``None`` or not an integer greater than 0.
 
     Example:
         >>> from torch import tensor
diff --git a/src/torchmetrics/retrieval/fall_out.py b/src/torchmetrics/retrieval/fall_out.py
index 2f633f25d89..7c4f031e1e6 100644
--- a/src/torchmetrics/retrieval/fall_out.py
+++ b/src/torchmetrics/retrieval/fall_out.py
@@ -40,7 +40,7 @@ class RetrievalFallOut(RetrievalMetric):
 
     As output to ``forward`` and ``compute`` the metric returns the following output:
 
-    - ``fo`` (:class:`~torch.Tensor`): A tensor with the computed metric
+    - ``fo@k`` (:class:`~torch.Tensor`): A tensor with the computed metric
 
     All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning,
     so that for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -55,9 +55,8 @@ class RetrievalFallOut(RetrievalMetric):
         - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
         - ``'error'``: raise a ``ValueError``
 
-        ignore_index:
-            Ignore predictions where the target is equal to this number.
-        top_k: consider only the top k elements for each query (default: `None`, which considers them all)
+        ignore_index: Ignore predictions where the target is equal to this number.
+        top_k: Consider only the top k elements for each query (default: `None`, which considers them all)
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
@@ -66,7 +65,7 @@ class RetrievalFallOut(RetrievalMetric):
         ValueError:
             If ``ignore_index`` is not `None` or an integer.
         ValueError:
-            If ``top_k`` parameter is not `None` or an integer larger than 0.
+            If ``top_k`` is not ``None`` or not an integer greater than 0.
 
     Example:
         >>> from torchmetrics.retrieval import RetrievalFallOut
diff --git a/src/torchmetrics/retrieval/hit_rate.py b/src/torchmetrics/retrieval/hit_rate.py
index 232456684cd..dc13fec4499 100644
--- a/src/torchmetrics/retrieval/hit_rate.py
+++ b/src/torchmetrics/retrieval/hit_rate.py
@@ -38,7 +38,7 @@ class RetrievalHitRate(RetrievalMetric):
 
     As output to ``forward`` and ``compute`` the metric returns the following output:
 
-    - ``hr2`` (:class:`~torch.Tensor`): A single-value tensor with the hit rate (at ``top_k``) of the predictions
+    - ``hr@k`` (:class:`~torch.Tensor`): A single-value tensor with the hit rate (at ``top_k``) of the predictions
       ``preds`` w.r.t. the labels ``target``
 
     All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning,
@@ -55,9 +55,8 @@ class RetrievalHitRate(RetrievalMetric):
         - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
         - ``'error'``: raise a ``ValueError``
 
-        ignore_index:
-            Ignore predictions where the target is equal to this number.
-        top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
+        ignore_index: Ignore predictions where the target is equal to this number.
+        top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
@@ -66,7 +65,7 @@ class RetrievalHitRate(RetrievalMetric):
         ValueError:
             If ``ignore_index`` is not `None` or an integer.
         ValueError:
-            If ``top_k`` parameter is not `None` or an integer larger than 0.
+            If ``top_k`` is not ``None`` or not an integer greater than 0.
 
     Example:
         >>> from torch import tensor
diff --git a/src/torchmetrics/retrieval/ndcg.py b/src/torchmetrics/retrieval/ndcg.py
index 2cee986ce14..9b917e9bffd 100644
--- a/src/torchmetrics/retrieval/ndcg.py
+++ b/src/torchmetrics/retrieval/ndcg.py
@@ -38,7 +38,7 @@ class RetrievalNormalizedDCG(RetrievalMetric):
 
     As output to ``forward`` and ``compute`` the metric returns the following output:
 
-    - ``ndcg`` (:class:`~torch.Tensor`): A single-value tensor with the nDCG of the predictions
+    - ``ndcg@k`` (:class:`~torch.Tensor`): A single-value tensor with the nDCG of the predictions
       ``preds`` w.r.t. the labels ``target``
 
     All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning,
@@ -55,9 +55,8 @@ class RetrievalNormalizedDCG(RetrievalMetric):
         - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
         - ``'error'``: raise a ``ValueError``
 
-        ignore_index:
-            Ignore predictions where the target is equal to this number.
-        top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
+        ignore_index: Ignore predictions where the target is equal to this number.
+        top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
@@ -66,7 +65,7 @@ class RetrievalNormalizedDCG(RetrievalMetric):
         ValueError:
             If ``ignore_index`` is not `None` or an integer.
         ValueError:
-            If ``top_k`` parameter is not `None` or an integer larger than 0.
+            If ``top_k`` is not ``None`` or not an integer greater than 0.
 
     Example:
         >>> from torch import tensor
diff --git a/src/torchmetrics/retrieval/precision.py b/src/torchmetrics/retrieval/precision.py
index f2c31bc8e4f..b6606ef65b2 100644
--- a/src/torchmetrics/retrieval/precision.py
+++ b/src/torchmetrics/retrieval/precision.py
@@ -38,7 +38,7 @@ class RetrievalPrecision(RetrievalMetric):
 
     As output to ``forward`` and ``compute`` the metric returns the following output:
 
-    - ``p2`` (:class:`~torch.Tensor`): A single-value tensor with the precision (at ``top_k``) of the predictions
+    - ``p@k`` (:class:`~torch.Tensor`): A single-value tensor with the precision (at ``top_k``) of the predictions
       ``preds`` w.r.t. the labels ``target``
 
     All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning,
@@ -54,10 +54,9 @@ class RetrievalPrecision(RetrievalMetric):
         - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
         - ``'error'``: raise a ``ValueError``
 
-        ignore_index:
-            Ignore predictions where the target is equal to this number.
-        top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
-        adaptive_k: adjust ``top_k`` to ``min(k, number of documents)`` for each query
+        ignore_index: Ignore predictions where the target is equal to this number.
+        top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
+        adaptive_k: Adjust ``top_k`` to ``min(k, number of documents)`` for each query
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
@@ -66,7 +65,7 @@ class RetrievalPrecision(RetrievalMetric):
         ValueError:
             If ``ignore_index`` is not `None` or an integer.
         ValueError:
-            If ``top_k`` is not `None` or an integer larger than 0.
+            If ``top_k`` is not ``None`` or not an integer greater than 0.
         ValueError:
             If ``adaptive_k`` is not boolean.
 
diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py
index b53a617ecde..26348892242 100644
--- a/src/torchmetrics/retrieval/precision_recall_curve.py
+++ b/src/torchmetrics/retrieval/precision_recall_curve.py
@@ -108,7 +108,7 @@ class RetrievalPrecisionRecallCurve(Metric):
         ValueError:
             If ``ignore_index`` is not `None` or an integer.
         ValueError:
-            If ``max_k`` parameter is not `None` or an integer larger than 0.
+            If ``max_k`` is not ``None`` or not an integer greater than 0.
 
     Example:
         >>> from torch import tensor
diff --git a/src/torchmetrics/retrieval/r_precision.py b/src/torchmetrics/retrieval/r_precision.py
index 1362bb40d40..3d4926473ca 100644
--- a/src/torchmetrics/retrieval/r_precision.py
+++ b/src/torchmetrics/retrieval/r_precision.py
@@ -38,7 +38,7 @@ class RetrievalRPrecision(RetrievalMetric):
 
     As output to ``forward`` and ``compute`` the metric returns the following output:
 
-    - ``p2`` (:class:`~torch.Tensor`): A single-value tensor with the r-precision of the predictions ``preds``
+    - ``rp`` (:class:`~torch.Tensor`): A single-value tensor with the r-precision of the predictions ``preds``
       w.r.t. the labels ``target``.
 
     All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning,
@@ -54,8 +54,7 @@ class RetrievalRPrecision(RetrievalMetric):
         - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
         - ``'error'``: raise a ``ValueError``
 
-        ignore_index:
-            Ignore predictions where the target is equal to this number.
+        ignore_index: Ignore predictions where the target is equal to this number.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
diff --git a/src/torchmetrics/retrieval/recall.py b/src/torchmetrics/retrieval/recall.py
index 534d1de6067..a6cfbb51b4c 100644
--- a/src/torchmetrics/retrieval/recall.py
+++ b/src/torchmetrics/retrieval/recall.py
@@ -38,7 +38,7 @@ class RetrievalRecall(RetrievalMetric):
 
     As output to ``forward`` and ``compute`` the metric returns the following output:
 
-    - ``r2`` (:class:`~torch.Tensor`): A single-value tensor with the recall (at ``top_k``) of the predictions
+    - ``r@k`` (:class:`~torch.Tensor`): A single-value tensor with the recall (at ``top_k``) of the predictions
       ``preds`` w.r.t. the labels ``target``
 
     All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning,
@@ -55,7 +55,7 @@ class RetrievalRecall(RetrievalMetric):
         - ``'error'``: raise a ``ValueError``
 
         ignore_index: Ignore predictions where the target is equal to this number.
-        top_k: consider only the top k elements for each query (default: `None`, which considers them all)
+        top_k: Consider only the top k elements for each query (default: `None`, which considers them all)
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
@@ -64,7 +64,7 @@ class RetrievalRecall(RetrievalMetric):
         ValueError:
             If ``ignore_index`` is not `None` or an integer.
         ValueError:
-            If ``top_k`` parameter is not `None` or an integer larger than 0.
+            If ``top_k`` is not ``None`` or not an integer greater than 0.
 
     Example:
         >>> from torch import tensor
diff --git a/src/torchmetrics/retrieval/reciprocal_rank.py b/src/torchmetrics/retrieval/reciprocal_rank.py
index 9e979178515..b16df38b4bc 100644
--- a/src/torchmetrics/retrieval/reciprocal_rank.py
+++ b/src/torchmetrics/retrieval/reciprocal_rank.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Union
 
 from torch import Tensor
 
@@ -38,8 +38,8 @@ class RetrievalMRR(RetrievalMetric):
 
     As output to ``forward`` and ``compute`` the metric returns the following output:
 
-    - ``mrr`` (:class:`~torch.Tensor`): A single-value tensor with the reciprocal rank (RR) of the predictions
-      ``preds`` w.r.t. the labels ``target``
+    - ``mrr@k`` (:class:`~torch.Tensor`): A single-value tensor with the mean reciprocal rank (MRR)
+      of the predictions ``preds`` w.r.t. the labels ``target``.
 
     All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning,
     so that for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -55,6 +55,7 @@ class RetrievalMRR(RetrievalMetric):
         - ``'error'``: raise a ``ValueError``
 
         ignore_index: Ignore predictions where the target is equal to this number.
+        top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
@@ -62,6 +63,8 @@
             If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``.
         ValueError:
             If ``ignore_index`` is not `None` or an integer.
+        ValueError:
+            If ``top_k`` is not ``None`` or not an integer greater than 0.
 
     Example:
         >>> from torch import tensor
@@ -81,8 +84,25 @@ class RetrievalMRR(RetrievalMetric):
     plot_lower_bound: float = 0.0
     plot_upper_bound: float = 1.0
 
+    def __init__(
+        self,
+        empty_target_action: str = "neg",
+        ignore_index: Optional[int] = None,
+        top_k: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            empty_target_action=empty_target_action,
+            ignore_index=ignore_index,
+            **kwargs,
+        )
+
+        if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
+            raise ValueError(f"Argument ``top_k`` has to be a positive integer or None, but got {top_k}")
+        self.top_k = top_k
+
     def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
-        return retrieval_reciprocal_rank(preds, target)
+        return retrieval_reciprocal_rank(preds, target, top_k=self.top_k)
 
     def plot(
         self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
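For the class interface, the same truncation applies per query. A short usage sketch (the example values are illustrative, not taken from the diff): with ``top_k=1`` only the single highest-scoring document per query can contribute, so the second query below scores zero.

    >>> from torch import tensor
    >>> from torchmetrics.retrieval import RetrievalMRR
    >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
    >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
    >>> target = tensor([False, False, True, False, True, False, True])
    >>> RetrievalMRR()(preds, target, indexes=indexes)         # query 0: RR=1, query 1: RR=1/2
    tensor(0.7500)
    >>> RetrievalMRR(top_k=1)(preds, target, indexes=indexes)  # query 1's top document is not relevant
    tensor(0.5000)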
diff --git a/tests/unittests/retrieval/test_mrr.py b/tests/unittests/retrieval/test_mrr.py
index f6943d633ca..2c2d566bc38 100644
--- a/tests/unittests/retrieval/test_mrr.py
+++ b/tests/unittests/retrieval/test_mrr.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
+
 import numpy as np
 import pytest
 from sklearn.metrics import label_ranking_average_precision_score
@@ -33,7 +35,7 @@
 seed_all(42)
 
 
-def _reciprocal_rank(target: np.ndarray, preds: np.ndarray):
+def _reciprocal_rank_at_k(target: np.ndarray, preds: np.ndarray, top_k: Optional[int] = None):
     """Adaptation of `sklearn.metrics.label_ranking_average_precision_score`.
 
     Since the original sklearn metric works as RR only when the number of positive targets is exactly 1, here we remove
@@ -44,6 +46,13 @@ def _reciprocal_rank(target: np.ndarray, preds: np.ndarray):
     assert target.shape == preds.shape
     assert len(target.shape) == 1  # works only with single dimension inputs
 
+    # take the k largest predictions here because sklearn does not support top-k truncation
+    if top_k is not None:
+        top_k = min(top_k, len(preds))
+        ind = np.argpartition(preds, -top_k)[-top_k:]
+        target = target[ind]
+        preds = preds[ind]
+
     # going to remove T targets that are not ranked as highest
     indexes = preds[target.astype(bool)]
     if len(indexes) > 0:
@@ -61,6 +70,7 @@ class TestMRR(RetrievalMetricTester):
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
     @pytest.mark.parametrize("ignore_index", [None, 1])  # avoid setting 0, otherwise test with all 0 targets will fail
+    @pytest.mark.parametrize("top_k", [None, 1, 4, 10])
     @pytest.mark.parametrize(**_default_metric_class_input_arguments)
     def test_class_metric(
         self,
@@ -70,9 +80,10 @@ def test_class_metric(
         ddp: bool,
         indexes: Tensor,
         preds: Tensor,
         target: Tensor,
         empty_target_action: str,
         ignore_index: int,
+        top_k: int,
     ):
         """Test class implementation of metric."""
-        metric_args = {"empty_target_action": empty_target_action, "ignore_index": ignore_index}
+        metric_args = {"empty_target_action": empty_target_action, "ignore_index": ignore_index, "top_k": top_k}
 
         self.run_class_metric_test(
             ddp=ddp,
@@ -80,12 +91,13 @@ def test_class_metric(
             preds=preds,
             target=target,
             metric_class=RetrievalMRR,
-            reference_metric=_reciprocal_rank,
+            reference_metric=_reciprocal_rank_at_k,
             metric_args=metric_args,
         )
 
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
+    @pytest.mark.parametrize("top_k", [None, 1, 4, 10])
     @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
     def test_class_metric_ignore_index(
         self,
@@ -94,9 +106,10 @@ def test_class_metric_ignore_index(
         ddp: bool,
         indexes: Tensor,
         preds: Tensor,
         target: Tensor,
         empty_target_action: str,
+        top_k: int,
     ):
         """Test class implementation of metric with ignore_index argument."""
-        metric_args = {"empty_target_action": empty_target_action, "ignore_index": -100}
+        metric_args = {"empty_target_action": empty_target_action, "ignore_index": -100, "top_k": top_k}
 
         self.run_class_metric_test(
             ddp=ddp,
@@ -104,19 +117,21 @@ def test_class_metric_ignore_index(
             preds=preds,
             target=target,
             metric_class=RetrievalMRR,
-            reference_metric=_reciprocal_rank,
+            reference_metric=_reciprocal_rank_at_k,
             metric_args=metric_args,
         )
 
     @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
-    def test_functional_metric(self, preds: Tensor, target: Tensor):
+    @pytest.mark.parametrize("top_k", [None, 1, 4, 10])
+    def test_functional_metric(self, preds: Tensor, target: Tensor, top_k: int):
         """Test functional implementation of metric."""
         self.run_functional_metric_test(
             preds=preds,
             target=target,
             metric_functional=retrieval_reciprocal_rank,
-            reference_metric=_reciprocal_rank,
+            reference_metric=_reciprocal_rank_at_k,
             metric_args={},
+            top_k=top_k,
         )
 
     @pytest.mark.parametrize(**_default_metric_class_input_arguments)
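To see what the updated reference checks, here is a minimal standalone sketch (my own example arrays, assuming the patched functional metric is installed): truncate to the k highest scores with ``np.argpartition``, take 1/rank of the first hit among the survivors, and compare against ``retrieval_reciprocal_rank``.

    import numpy as np
    import torch

    from torchmetrics.functional.retrieval import retrieval_reciprocal_rank

    preds = np.array([0.2, 0.3, 0.5, 0.1])
    target = np.array([False, True, False, True])
    top_k = 2

    # NumPy reference: keep the k highest-scoring documents, then RR = 1 / rank of first hit
    ind = np.argpartition(preds, -top_k)[-top_k:]        # indices of the k largest predictions
    kept = target[ind][np.argsort(-preds[ind])]          # kept targets in descending score order
    hits = np.flatnonzero(kept)
    expected = 1.0 / (hits[0] + 1) if len(hits) else 0.0  # 0.5 here: first hit at rank 2

    actual = retrieval_reciprocal_rank(torch.from_numpy(preds), torch.from_numpy(target), top_k=top_k)
    assert np.isclose(actual.item(), expected)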