Skip to content

Commit dce2368

Browse files
fix ignored custom callable in retrieval metric aggregation (Lightning-AI#2364)
* fix retrieval aggregation * fix retrieval tests * changelog --------- Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
1 parent b6f6e07 commit dce2368

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3636
- Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))
3737

3838

39+
- Fixed custom aggregation in retrieval metrics ([#2364](https://github.com/Lightning-AI/torchmetrics/pull/2364))
40+
41+
3942
- Fixed initialize aggregation metrics with default floating type ([#2366](https://github.com/Lightning-AI/torchmetrics/pull/2366))
4043

4144
---

src/torchmetrics/retrieval/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _retrieval_aggregate(
3535
return values.median() if dim is None else values.median(dim=dim).values
3636
if aggregation == "min":
3737
return values.min() if dim is None else values.min(dim=dim).values
38-
if aggregation:
38+
if aggregation == "max":
3939
return values.max() if dim is None else values.max(dim=dim).values
4040
return aggregation(values, dim=dim)
4141

tests/unittests/retrieval/helpers.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import torch
2020
from numpy import array
2121
from torch import Tensor, tensor
22-
from torchmetrics.retrieval.base import _retrieval_aggregate
2322
from typing_extensions import Literal
2423

2524
from unittests.helpers import seed_all
@@ -42,6 +41,23 @@
4241
# a version of get_group_indexes that depends on NumPy is here to avoid this dependency for the full library
4342

4443

44+
def _retrieval_aggregate(
45+
values: Tensor,
46+
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
47+
dim: Optional[int] = None,
48+
) -> Tensor:
49+
"""Aggregate the final retrieval values into a single value."""
50+
if aggregation == "mean":
51+
return values.mean() if dim is None else values.mean(dim=dim)
52+
if aggregation == "median":
53+
return values.median() if dim is None else values.median(dim=dim).values
54+
if aggregation == "min":
55+
return values.min() if dim is None else values.min(dim=dim).values
56+
if aggregation == "max":
57+
return values.max() if dim is None else values.max(dim=dim).values
58+
return aggregation(values, dim=dim)
59+
60+
4561
def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]:
4662
"""Extract group indexes.
4763
@@ -74,7 +90,7 @@ def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor,
7490

7591

7692
def _custom_aggregate_fn(val: Tensor, dim=None) -> Tensor:
77-
return (val**2).mean(dim=dim)
93+
return (val**2).mean() if dim is None else (val**2).mean(dim=dim)
7894

7995

8096
def _compute_sklearn_metric(

0 commit comments

Comments
 (0)