Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregate argument to retrieval metrics #2220

Merged
merged 18 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for logging `MultiTaskWrapper` directly with lightnings `log_dict` method ([#2213](https://github.com/Lightning-AI/torchmetrics/pull/2213))


- Added `aggregate`` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))

### Changed

- Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089))
Expand Down
14 changes: 13 additions & 1 deletion src/torchmetrics/retrieval/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.retrieval.base import RetrievalMetric
Expand Down Expand Up @@ -56,6 +57,15 @@ class RetrievalMAP(RetrievalMetric):

ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
aggregation:
Specify how to aggregate over indexes. Can either a custom callable function that takes in a single tensor
and returns a scalar value or one of the following strings:

- ``'mean'``: average value is returned
- ``'median'``: median value is returned
- ``'max'``: max value is returned
- ``'min'``: min value is returned

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand Down Expand Up @@ -89,11 +99,13 @@ def __init__(
empty_target_action: str = "neg",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
**kwargs: Any,
) -> None:
super().__init__(
empty_target_action=empty_target_action,
ignore_index=ignore_index,
aggregation=aggregation,
**kwargs,
)

Expand Down
43 changes: 39 additions & 4 deletions src/torchmetrics/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics import Metric
from torchmetrics.utilities.checks import _check_retrieval_inputs
from torchmetrics.utilities.data import _flexible_bincount, dim_zero_cat


def _retrieval_aggregate(
values: Tensor,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
dim: Optional[int] = None,
) -> Tensor:
"""Aggregate the final retrieval values into a single value."""
if aggregation == "mean":
return values.mean() if dim is None else values.mean(dim=dim)
if aggregation == "median":
return values.median() if dim is None else values.median(dim=dim).values
if aggregation == "min":
return values.min() if dim is None else values.min(dim=dim).values
if aggregation:
return values.max() if dim is None else values.max(dim=dim).values
return aggregation(values, dim=dim)


class RetrievalMetric(Metric, ABC):
"""Works with binary target data. Accepts float predictions from a model output.

Expand Down Expand Up @@ -56,6 +74,15 @@ class RetrievalMetric(Metric, ABC):

ignore_index:
Ignore predictions where the target is equal to this number.
aggregation:
Specify how to aggregate over indexes. Can either a custom callable function that takes in a single tensor
and returns a scalar value or one of the following strings:

- ``'mean'``: average value is returned
- ``'median'``: median value is returned
- ``'max'``: max value is returned
- ``'min'``: min value is returned

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand All @@ -78,6 +105,7 @@ def __init__(
self,
empty_target_action: str = "neg",
ignore_index: Optional[int] = None,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -86,14 +114,19 @@ def __init__(
empty_target_action_options = ("error", "skip", "neg", "pos")
if empty_target_action not in empty_target_action_options:
raise ValueError(f"Argument `empty_target_action` received a wrong value `{empty_target_action}`.")

self.empty_target_action = empty_target_action

if ignore_index is not None and not isinstance(ignore_index, int):
raise ValueError("Argument `ignore_index` must be an integer or None.")

self.ignore_index = ignore_index

if not (aggregation in ("mean", "median", "min", "max") or callable(aggregation)):
raise ValueError(
"Argument `aggregation` must be one of `mean`, `median`, `min`, `max` or a custom callable function"
f"which takes tensor of values, but got {aggregation}."
)
self.aggregation = aggregation

self.add_state("indexes", default=[], dist_reduce_fx=None)
self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)
Expand Down Expand Up @@ -144,7 +177,9 @@ def compute(self) -> Tensor:
# ensure list contains only float tensors
res.append(self._metric(mini_preds, mini_target))

return torch.stack([x.to(preds) for x in res]).mean() if res else tensor(0.0).to(preds)
if res:
return _retrieval_aggregate(torch.stack([x.to(preds) for x in res]), self.aggregation)
return tensor(0.0).to(preds)

@abstractmethod
def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
Expand Down
22 changes: 19 additions & 3 deletions src/torchmetrics/retrieval/fall_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.retrieval.base import RetrievalMetric
from torchmetrics.retrieval.base import RetrievalMetric, _retrieval_aggregate
from torchmetrics.utilities.data import _flexible_bincount, dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
Expand Down Expand Up @@ -57,6 +58,15 @@ class RetrievalFallOut(RetrievalMetric):

ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: `None`, which considers them all)
aggregation:
Specify how to aggregate over indexes. Can either a custom callable function that takes in a single tensor
and returns a scalar value or one of the following strings:

- ``'mean'``: average value is returned
- ``'median'``: median value is returned
- ``'max'``: max value is returned
- ``'min'``: min value is returned

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand Down Expand Up @@ -89,11 +99,13 @@ def __init__(
empty_target_action: str = "pos",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
**kwargs: Any,
) -> None:
super().__init__(
empty_target_action=empty_target_action,
ignore_index=ignore_index,
aggregation=aggregation,
**kwargs,
)

Expand Down Expand Up @@ -134,7 +146,11 @@ def compute(self) -> Tensor:
# ensure list contains only float tensors
res.append(self._metric(mini_preds, mini_target))

return torch.stack([x.to(preds) for x in res]).mean() if res else tensor(0.0).to(preds)
return (
_retrieval_aggregate(torch.stack([x.to(preds) for x in res]), aggregation=self.aggregation)
if res
else tensor(0.0).to(preds)
)

def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
return retrieval_fall_out(preds, target, top_k=self.top_k)
Expand Down
14 changes: 13 additions & 1 deletion src/torchmetrics/retrieval/hit_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.retrieval.base import RetrievalMetric
Expand Down Expand Up @@ -57,6 +58,15 @@ class RetrievalHitRate(RetrievalMetric):

ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
aggregation:
Specify how to aggregate over indexes. Can either a custom callable function that takes in a single tensor
and returns a scalar value or one of the following strings:

- ``'mean'``: average value is returned
- ``'median'``: median value is returned
- ``'max'``: max value is returned
- ``'min'``: min value is returned

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand Down Expand Up @@ -90,11 +100,13 @@ def __init__(
empty_target_action: str = "neg",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
**kwargs: Any,
) -> None:
super().__init__(
empty_target_action=empty_target_action,
ignore_index=ignore_index,
aggregation=aggregation,
**kwargs,
)

Expand Down
14 changes: 13 additions & 1 deletion src/torchmetrics/retrieval/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.retrieval.base import RetrievalMetric
Expand Down Expand Up @@ -57,6 +58,15 @@ class RetrievalNormalizedDCG(RetrievalMetric):

ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
aggregation:
Specify how to aggregate over indexes. Can either a custom callable function that takes in a single tensor
and returns a scalar value or one of the following strings:

- ``'mean'``: average value is returned
- ``'median'``: median value is returned
- ``'max'``: max value is returned
- ``'min'``: min value is returned

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand Down Expand Up @@ -90,11 +100,13 @@ def __init__(
empty_target_action: str = "neg",
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
**kwargs: Any,
) -> None:
super().__init__(
empty_target_action=empty_target_action,
ignore_index=ignore_index,
aggregation=aggregation,
**kwargs,
)

Expand Down
14 changes: 13 additions & 1 deletion src/torchmetrics/retrieval/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.retrieval.base import RetrievalMetric
Expand Down Expand Up @@ -57,6 +58,15 @@ class RetrievalPrecision(RetrievalMetric):
ignore_index: Ignore predictions where the target is equal to this number.
top_k: Consider only the top k elements for each query (default: ``None``, which considers them all)
adaptive_k: Adjust ``top_k`` to ``min(k, number of documents)`` for each query
aggregation:
Specify how to aggregate over indexes. Can either a custom callable function that takes in a single tensor
and returns a scalar value or one of the following strings:

- ``'mean'``: average value is returned
- ``'median'``: median value is returned
- ``'max'``: max value is returned
- ``'min'``: min value is returned

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand Down Expand Up @@ -93,11 +103,13 @@ def __init__(
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
adaptive_k: bool = False,
aggregation: Union[Literal["mean", "median", "min", "max"], Callable] = "mean",
**kwargs: Any,
) -> None:
super().__init__(
empty_target_action=empty_target_action,
ignore_index=ignore_index,
aggregation=aggregation,
**kwargs,
)

Expand Down
Loading
Loading