|
19 | 19 | import torch |
20 | 20 | from numpy import array |
21 | 21 | from torch import Tensor, tensor |
22 | | -from torchmetrics.retrieval.base import _retrieval_aggregate |
23 | 22 | from typing_extensions import Literal |
24 | 23 |
|
25 | 24 | from unittests.helpers import seed_all |
|
42 | 41 | # a version of get_group_indexes that depends on NumPy is here to avoid this dependency for the full library |
43 | 42 |
|
44 | 43 |
|
| 44 | +def _retrieval_aggregate( |
| 45 | + values: Tensor, |
| 46 | + aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean", |
| 47 | + dim: Optional[int] = None, |
| 48 | +) -> Tensor: |
| 49 | + """Aggregate the final retrieval values into a single value.""" |
| 50 | + if aggregation == "mean": |
| 51 | + return values.mean() if dim is None else values.mean(dim=dim) |
| 52 | + if aggregation == "median": |
| 53 | + return values.median() if dim is None else values.median(dim=dim).values |
| 54 | + if aggregation == "min": |
| 55 | + return values.min() if dim is None else values.min(dim=dim).values |
| 56 | + if aggregation == "max": |
| 57 | + return values.max() if dim is None else values.max(dim=dim).values |
| 58 | + return aggregation(values, dim=dim) |
| 59 | + |
| 60 | + |
45 | 61 | def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, np.ndarray]]: |
46 | 62 | """Extract group indexes. |
47 | 63 |
|
@@ -74,7 +90,7 @@ def get_group_indexes(indexes: Union[Tensor, np.ndarray]) -> List[Union[Tensor, |
74 | 90 |
|
75 | 91 |
|
76 | 92 | def _custom_aggregate_fn(val: Tensor, dim=None) -> Tensor: |
77 | | - return (val**2).mean(dim=dim) |
| 93 | + return (val**2).mean() if dim is None else (val**2).mean(dim=dim) |
78 | 94 |
|
79 | 95 |
|
80 | 96 | def _compute_sklearn_metric( |
|
0 commit comments