Adding MRR@K as an option to MRR (#1961)
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
lucadiliello and SkafteNicki authored Aug 3, 2023
1 parent ab01bbd commit ea17ffb
Showing 12 changed files with 90 additions and 47 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added warning to `PearsonCorrCoeff` if input has a very small variance for its given dtype ([#1926](https://github.com/Lightning-AI/torchmetrics/pull/1926))


- Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))


### Changed

-
15 changes: 13 additions & 2 deletions src/torchmetrics/functional/retrieval/reciprocal_rank.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor, tensor

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
def retrieval_reciprocal_rank(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
"""Compute reciprocal rank (for information retrieval). See `Mean Reciprocal Rank`_.
``preds`` and ``target`` should be of the same shape and live on the same device. If no element of ``target`` is ``True``,
@@ -27,10 +29,15 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
top_k: consider only the top k elements (default: ``None``, which considers them all)
Return:
a single-value tensor with the reciprocal rank (RR) of the predictions ``preds`` w.r.t. the labels ``target``.
Raises:
ValueError:
If ``top_k`` is not ``None`` or not an integer greater than 0.
Example:
>>> from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
>>> preds = torch.tensor([0.2, 0.3, 0.5])
@@ -41,9 +48,13 @@ def retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

top_k = top_k or preds.shape[-1]
if not (isinstance(top_k, int) and top_k > 0):
raise ValueError(f"Argument ``top_k`` has to be a positive integer or None, but got {top_k}.")

target = target[preds.topk(min(top_k, preds.shape[-1]), sorted=True, dim=-1)[1]]
if not target.sum():
return tensor(0.0, device=preds.device)

target = target[torch.argsort(preds, dim=-1, descending=True)]
position = torch.nonzero(target).view(-1)
return 1.0 / (position[0] + 1.0)
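
For readers skimming the diff, a minimal usage sketch (not part of the commit) contrasting plain reciprocal rank with the new ``top_k`` behavior; the import path comes from the file changed above, and the expected outputs are worked out by hand from the ranking logic:

>>> import torch
>>> from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([False, True, False])
>>> # Full RR: the only relevant document is ranked 2nd -> 1/2
>>> retrieval_reciprocal_rank(preds, target)
tensor(0.5000)
>>> # RR@1: only the top-1 prediction is kept, and it is not relevant -> 0
>>> retrieval_reciprocal_rank(preds, target, top_k=1)
tensor(0.)

Here ``top_k=1`` truncates the ranking before the relevant document is reached, which is exactly the case handled by the early ``return tensor(0.0, device=preds.device)`` branch.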
11 changes: 5 additions & 6 deletions src/torchmetrics/retrieval/average_precision.py
@@ -38,8 +38,8 @@ class RetrievalMAP(RetrievalMetric):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``rmap`` (:class:`~torch.Tensor`): A tensor with the mean average precision of the predictions ``preds``
w.r.t. the labels ``target``
- ``map@k`` (:class:`~torch.Tensor`): A single-value tensor with the mean average precision (MAP)
of the predictions ``preds`` w.r.t. the labels ``target``.
All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
so that, for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -54,9 +54,8 @@ class RetrievalMAP(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -65,7 +64,7 @@
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` is not ``None`` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.
Example:
>>> from torch import tensor
9 changes: 4 additions & 5 deletions src/torchmetrics/retrieval/fall_out.py
@@ -40,7 +40,7 @@ class RetrievalFallOut(RetrievalMetric):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``fo`` (:class:`~torch.Tensor`): A tensor with the computed metric
- ``fo@k`` (:class:`~torch.Tensor`): A tensor with the computed metric
All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
so that, for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -55,9 +55,8 @@ class RetrievalFallOut(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: `None`, which considers them all)
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -66,7 +65,7 @@
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` parameter is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.
Example:
>>> from torchmetrics.retrieval import RetrievalFallOut
9 changes: 4 additions & 5 deletions src/torchmetrics/retrieval/hit_rate.py
@@ -38,7 +38,7 @@ class RetrievalHitRate(RetrievalMetric):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``hr2`` (:class:`~torch.Tensor`): A single-value tensor with the hit rate (at ``top_k``) of the predictions
- ``hr@k`` (:class:`~torch.Tensor`): A single-value tensor with the hit rate (at ``top_k``) of the predictions
``preds`` w.r.t. the labels ``target``
All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -55,9 +55,8 @@ class RetrievalHitRate(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -66,7 +65,7 @@
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` parameter is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.
Example:
>>> from torch import tensor
9 changes: 4 additions & 5 deletions src/torchmetrics/retrieval/ndcg.py
@@ -38,7 +38,7 @@ class RetrievalNormalizedDCG(RetrievalMetric):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``ndcg`` (:class:`~torch.Tensor`): A single-value tensor with the nDCG of the predictions
- ``ndcg@k`` (:class:`~torch.Tensor`): A single-value tensor with the nDCG of the predictions
``preds`` w.r.t. the labels ``target``
All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -55,9 +55,8 @@ class RetrievalNormalizedDCG(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -66,7 +65,7 @@
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` parameter is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.
Example:
>>> from torch import tensor
11 changes: 5 additions & 6 deletions src/torchmetrics/retrieval/precision.py
@@ -38,7 +38,7 @@ class RetrievalPrecision(RetrievalMetric):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``p2`` (:class:`~torch.Tensor`): A single-value tensor with the precision (at ``top_k``) of the predictions
- ``p@k`` (:class:`~torch.Tensor`): A single-value tensor with the precision (at ``top_k``) of the predictions
``preds`` w.r.t. the labels ``target``
All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -54,10 +54,9 @@ class RetrievalPrecision(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
ignore_index:
Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: ``None``, which considers them all)
adaptive_k: adjust ``top_k`` to ``min(k, number of documents)`` for each query
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
adaptive_k: Adjust ``top_k`` to ``min(k, number of documents)`` for each query
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -66,7 +65,7 @@
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.
ValueError:
If ``adaptive_k`` is not boolean.
2 changes: 1 addition & 1 deletion src/torchmetrics/retrieval/precision_recall_curve.py
@@ -108,7 +108,7 @@ class RetrievalPrecisionRecallCurve(Metric):
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``max_k`` parameter is not `None` or an integer larger than 0.
If ``max_k`` is not ``None`` or not an integer greater than 0.
Example:
>>> from torch import tensor
5 changes: 2 additions & 3 deletions src/torchmetrics/retrieval/r_precision.py
@@ -38,7 +38,7 @@ class RetrievalRPrecision(RetrievalMetric):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``p2`` (:class:`~torch.Tensor`): A single-value tensor with the r-precision of the predictions ``preds``
- ``rp`` (:class:`~torch.Tensor`): A single-value tensor with the r-precision of the predictions ``preds``
w.r.t. the labels ``target``.
All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -54,8 +54,7 @@ class RetrievalRPrecision(RetrievalMetric):
- ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
ignore_index:
Ignore predictions where the target is equal to this number.
ignore_index: Ignore predictions where the target is equal to this number.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
6 changes: 3 additions & 3 deletions src/torchmetrics/retrieval/recall.py
@@ -38,7 +38,7 @@ class RetrievalRecall(RetrievalMetric):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``r2`` (:class:`~torch.Tensor`): A single-value tensor with the recall (at ``top_k``) of the predictions
- ``r@k`` (:class:`~torch.Tensor`): A single-value tensor with the recall (at ``top_k``) of the predictions
``preds`` w.r.t. the labels ``target``
All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
@@ -55,7 +55,7 @@ class RetrievalRecall(RetrievalMetric):
- ``'error'``: raise a ``ValueError``
ignore_index: Ignore predictions where the target is equal to this number.
top_k: consider only the top k elements for each query (default: `None`, which considers them all)
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -64,7 +64,7 @@
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` parameter is not `None` or an integer larger than 0.
If ``top_k`` is not ``None`` or not an integer greater than 0.
Example:
>>> from torch import tensor
28 changes: 24 additions & 4 deletions src/torchmetrics/retrieval/reciprocal_rank.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Union
from typing import Any, Optional, Sequence, Union

from torch import Tensor

@@ -38,8 +38,8 @@ class RetrievalMRR(RetrievalMetric):
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``mrr`` (:class:`~torch.Tensor`): A single-value tensor with the reciprocal rank (RR) of the predictions
``preds`` w.r.t. the labels ``target``
- ``mrr@k`` (:class:`~torch.Tensor`): A single-value tensor with the mean reciprocal rank (MRR)
of the predictions ``preds`` w.r.t. the labels ``target``.
All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened at the beginning,
so that, for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by
@@ -55,13 +55,16 @@
- ``'error'``: raise a ``ValueError``
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
ValueError:
If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``.
ValueError:
If ``ignore_index`` is not `None` or an integer.
ValueError:
If ``top_k`` is not ``None`` or not an integer greater than 0.
Example:
>>> from torch import tensor
@@ -81,8 +84,25 @@
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0

def __init__(
self,
empty_target_action: str = "neg",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(
empty_target_action=empty_target_action,
ignore_index=ignore_index,
**kwargs,
)

if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
raise ValueError(f"Argument ``top_k`` has to be a positive integer or None, but got {top_k}.")
self.top_k = top_k

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_reciprocal_rank(preds, target)
return retrieval_reciprocal_rank(preds, target, top_k=self.top_k)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
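
And a matching sketch (also not part of the commit) for the module interface, reusing the input tensors from the class docstring example; ``RetrievalMRR`` first groups predictions by ``indexes`` and then averages the per-query reciprocal ranks, so truncating with ``top_k`` can lower the score:

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalMRR
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> # Full MRR: query 0 -> RR 1.0, query 1 -> RR 0.5, mean 0.75
>>> RetrievalMRR()(preds, target, indexes=indexes)
tensor(0.7500)
>>> # MRR@1: query 0's top document is relevant (1.0), query 1's is not (0.0)
>>> RetrievalMRR(top_k=1)(preds, target, indexes=indexes)
tensor(0.5000)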