Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix for using metric collection and aggregation metric #1896

Merged
merged 10 commits into from
Jul 12, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixes corner case when using `MetricCollection` together with aggregation metrics ([#1896](https://github.com/Lightning-AI/torchmetrics/pull/1896))


- Fixed the use of `max_fpr` in `AUROC` metric when only one class is present ([#1895](https://github.com/Lightning-AI/torchmetrics/pull/1895))


Expand Down
30 changes: 22 additions & 8 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ class BaseAggregator(Metric):
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided will impude any `nan` values with this value

state_name: name of the metric state
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
"""

value: Tensor
is_differentiable = None
higher_is_better = None
full_state_update: bool = False
Expand All @@ -56,6 +56,7 @@ def __init__(
fn: Union[Callable, str],
default_value: Union[Tensor, List],
nan_strategy: Union[str, float] = "error",
state_name: str = "value",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -67,7 +68,8 @@ def __init__(
)

self.nan_strategy = nan_strategy
self.add_state("value", default=default_value, dist_reduce_fx=fn)
self.add_state(state_name, default=default_value, dist_reduce_fx=fn)
self.state_name = state_name

def _cast_and_nan_check_input(
self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None
Expand Down Expand Up @@ -105,7 +107,7 @@ def update(self, value: Union[float, Tensor]) -> None:

def compute(self) -> Tensor:
"""Compute the aggregated value."""
return self.value
return getattr(self, self.state_name)


class MaxMetric(BaseAggregator):
Expand Down Expand Up @@ -144,6 +146,7 @@ class MaxMetric(BaseAggregator):
"""

full_state_update: bool = True
max_value: Tensor

def __init__(
self,
Expand All @@ -154,6 +157,7 @@ def __init__(
"max",
-torch.tensor(float("inf")),
nan_strategy,
state_name="max_value",
**kwargs,
)

Expand All @@ -166,7 +170,7 @@ def update(self, value: Union[float, Tensor]) -> None:
"""
value, _ = self._cast_and_nan_check_input(value)
if value.numel(): # make sure tensor not empty
self.value = torch.max(self.value, torch.max(value))
self.max_value = torch.max(self.max_value, torch.max(value))

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -244,6 +248,7 @@ class MinMetric(BaseAggregator):
"""

full_state_update: bool = True
min_value: Tensor

def __init__(
self,
Expand All @@ -254,6 +259,7 @@ def __init__(
"min",
torch.tensor(float("inf")),
nan_strategy,
state_name="min_value",
**kwargs,
)

Expand All @@ -266,7 +272,7 @@ def update(self, value: Union[float, Tensor]) -> None:
"""
value, _ = self._cast_and_nan_check_input(value)
if value.numel(): # make sure tensor not empty
self.value = torch.min(self.value, torch.min(value))
self.min_value = torch.min(self.min_value, torch.min(value))

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -343,6 +349,8 @@ class SumMetric(BaseAggregator):
tensor(6.)
"""

sum_value: Tensor

def __init__(
self,
nan_strategy: Union[str, float] = "warn",
Expand All @@ -352,6 +360,7 @@ def __init__(
"sum",
torch.tensor(0.0),
nan_strategy,
state_name="sum_value",
**kwargs,
)

Expand All @@ -364,7 +373,7 @@ def update(self, value: Union[float, Tensor]) -> None:
"""
value, _ = self._cast_and_nan_check_input(value)
if value.numel():
self.value += value.sum()
self.sum_value += value.sum()

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -442,6 +451,8 @@ class CatMetric(BaseAggregator):
tensor([1., 2., 3.])
"""

value: Tensor

def __init__(
self,
nan_strategy: Union[str, float] = "warn",
Expand Down Expand Up @@ -503,6 +514,8 @@ class MeanMetric(BaseAggregator):
tensor(2.)
"""

mean_value: Tensor

def __init__(
self,
nan_strategy: Union[str, float] = "warn",
Expand All @@ -512,6 +525,7 @@ def __init__(
"sum",
torch.tensor(0.0),
nan_strategy,
state_name="mean_value",
**kwargs,
)
self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum")
Expand All @@ -537,12 +551,12 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0

if value.numel() == 0:
return
self.value += (value * weight).sum()
self.mean_value += (value * weight).sum()
self.weight += weight.sum()

def compute(self) -> Tensor:
"""Compute the aggregated value."""
return self.value / self.weight
return self.mean_value / self.weight

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
16 changes: 8 additions & 8 deletions tests/integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,22 +419,22 @@ def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)

model = BoringModel()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32
model = model.half()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32

model = BoringModel()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32
model = model.double()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32

model = BoringModel(metric_dtype=torch.float16)
assert model.metric.value.dtype == torch.float16
assert model.metric.sum_value.dtype == torch.float16
model = model.float()
assert model.metric.value.dtype == torch.float16
assert model.metric.sum_value.dtype == torch.float16

model = BoringModel()
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32

model = model.type(torch.half)
assert model.metric.value.dtype == torch.float32
assert model.metric.sum_value.dtype == torch.float32
17 changes: 17 additions & 0 deletions tests/unittests/bases/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import torch
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric
from torchmetrics.collections import MetricCollection

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers.testers import MetricTester
Expand Down Expand Up @@ -166,6 +167,22 @@ def test_mean_metric_broadcasting(weights, expected):
assert avg(values, weights) == expected


def test_aggregation_in_collection_with_compute_groups():
"""Check that aggregation metrics work in MetricCollection with compute_groups=True."""
m = MetricCollection(MinMetric(), MaxMetric(), SumMetric(), MeanMetric(), compute_groups=True)
assert len(m.compute_groups) == 4, "Expected 4 compute groups"
m.update(1)
assert len(m.compute_groups) == 4, "Expected 4 compute groups"
m.update(2)
assert len(m.compute_groups) == 4, "Expected 4 compute groups"

res = m.compute()
assert res["MinMetric"] == 1
assert res["MaxMetric"] == 2
assert res["SumMetric"] == 3
assert res["MeanMetric"] == 1.5


@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to")
@pytest.mark.parametrize("nan_strategy", ["ignore", "warn"])
def test_mean_metric_broadcast(nan_strategy):
Expand Down