Fix meanmetric broadcasting for Nan values (#1898)
(cherry picked from commit 393a978)
SkafteNicki authored and Borda committed Jul 13, 2023
1 parent efcc32e commit 39361f3
Showing 3 changed files with 54 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed bug related to empty predictions for `IntersectionOverUnion` metric ([#1892](https://github.com/Lightning-AI/torchmetrics/pull/1892))
 
 
+- Fixed bug related to `MeanMetric` and broadcasting of weights when Nans are present ([#1898](https://github.com/Lightning-AI/torchmetrics/pull/1898))
+
+
 ## [1.0.0] - 2023-07-04
 
 ### Added
42 changes: 28 additions & 14 deletions src/torchmetrics/aggregation.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Optional, Sequence, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -69,25 +69,36 @@ def __init__(
         self.nan_strategy = nan_strategy
         self.add_state("value", default=default_value, dist_reduce_fx=fn)
 
-    def _cast_and_nan_check_input(self, x: Union[float, Tensor]) -> Tensor:
+    def _cast_and_nan_check_input(
+        self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None
+    ) -> Tuple[Tensor, Tensor]:
         """Convert input ``x`` to a tensor and check for Nans."""
         if not isinstance(x, Tensor):
             x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
+        if weight is not None and not isinstance(weight, Tensor):
+            weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device)
 
         nans = torch.isnan(x)
-        if nans.any():
+        if weight is not None:
+            nans_weight = torch.isnan(weight)
+        else:
+            nans_weight = torch.zeros_like(nans).bool()
+            weight = torch.ones_like(x)
+        if nans.any() or nans_weight.any():
             if self.nan_strategy == "error":
                 raise RuntimeError("Encounted `nan` values in tensor")
             if self.nan_strategy in ("ignore", "warn"):
                 if self.nan_strategy == "warn":
                     rank_zero_warn("Encounted `nan` values in tensor. Will be removed.", UserWarning)
-                x = x[~nans]
+                x = x[~(nans | nans_weight)]
+                weight = weight[~(nans | nans_weight)]
             else:
                 if not isinstance(self.nan_strategy, float):
                     raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
-                x[nans] = self.nan_strategy
+                x[nans | nans_weight] = self.nan_strategy
+                weight[nans | nans_weight] = self.nan_strategy
 
-        return x.float()
+        return x.float(), weight.float()
 
     def update(self, value: Union[float, Tensor]) -> None:
         """Overwrite in child class."""
@@ -153,7 +164,7 @@ def update(self, value: Union[float, Tensor]) -> None:
             value: Either a float or tensor containing data. Additional tensor
                 dimensions will be flattened
         """
-        value = self._cast_and_nan_check_input(value)
+        value, _ = self._cast_and_nan_check_input(value)
         if value.numel():  # make sure tensor not empty
             self.value = torch.max(self.value, torch.max(value))
 
@@ -253,7 +264,7 @@ def update(self, value: Union[float, Tensor]) -> None:
             value: Either a float or tensor containing data. Additional tensor
                 dimensions will be flattened
         """
-        value = self._cast_and_nan_check_input(value)
+        value, _ = self._cast_and_nan_check_input(value)
         if value.numel():  # make sure tensor not empty
             self.value = torch.min(self.value, torch.min(value))
 
@@ -351,7 +362,7 @@ def update(self, value: Union[float, Tensor]) -> None:
             value: Either a float or tensor containing data. Additional tensor
                 dimensions will be flattened
         """
-        value = self._cast_and_nan_check_input(value)
+        value, _ = self._cast_and_nan_check_input(value)
         if value.numel():
             self.value += value.sum()
 
@@ -445,7 +456,7 @@ def update(self, value: Union[float, Tensor]) -> None:
             value: Either a float or tensor containing data. Additional tensor
                 dimensions will be flattened
         """
-        value = self._cast_and_nan_check_input(value)
+        value, _ = self._cast_and_nan_check_input(value)
         if value.numel():
             self.value.append(value)
 
@@ -516,13 +527,16 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0
                 the shape of `value`. Default to `1.0` corresponding to simple
                 harmonic average.
         """
-        value = self._cast_and_nan_check_input(value)
-        weight = self._cast_and_nan_check_input(weight)
+        # broadcast weight to value shape
+        if not isinstance(value, Tensor):
+            value = torch.as_tensor(value, dtype=torch.float32, device=self.device)
+        if weight is not None and not isinstance(weight, Tensor):
+            weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device)
+        weight = torch.broadcast_to(weight, value.shape)
+        value, weight = self._cast_and_nan_check_input(value, weight)
 
         if value.numel() == 0:
             return
-        # broadcast weight to value shape
-        weight = torch.broadcast_to(weight, value.shape)
         self.value += (value * weight).sum()
         self.weight += weight.sum()
 
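To see why the ordering matters (a usage sketch, not part of the commit): previously `value` and `weight` were passed through `_cast_and_nan_check_input` separately, so removing a NaN from one tensor but not the other left them with different lengths, and the later `torch.broadcast_to(weight, value.shape)` failed. After this change the broadcast happens first and the NaN mask is applied to both tensors at once:

```python
import torch
from torchmetrics import MeanMetric

metric = MeanMetric(nan_strategy="ignore")

values = torch.tensor([1.0, float("nan"), 3.0])
weights = torch.tensor([1.0, 1.0, 2.0])

# The NaN position is dropped from both tensors, so the result is
# (1 * 1 + 3 * 2) / (1 + 2) = 7 / 3.
metric.update(values, weights)
print(metric.compute())  # tensor(2.3333)
```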
23 changes: 23 additions & 0 deletions tests/unittests/bases/test_aggregation.py
@@ -164,3 +164,26 @@ def test_mean_metric_broadcasting(weights, expected):
     avg = MeanMetric()
 
     assert avg(values, weights) == expected
+
+
+@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to")
+@pytest.mark.parametrize("nan_strategy", ["ignore", "warn"])
+def test_mean_metric_broadcast(nan_strategy):
+    """Check that weights gets broadcasted correctly when Nans are present."""
+    metric = MeanMetric(nan_strategy=nan_strategy)
+
+    x = torch.arange(5).float()
+    x[1] = torch.tensor(float("nan"))
+    w = torch.arange(5).float()
+
+    metric.update(x, w)
+    res = metric.compute()
+    assert round(res.item(), 4) == 3.2222  # (0*0 + 2*2 + 3*3 + 4*4) / (0 + 2 + 3 + 4)
+
+    x = torch.arange(5).float()
+    w = torch.arange(5).float()
+    w[1] = torch.tensor(float("nan"))
+
+    metric.update(x, w)
+    res = metric.compute()
+    assert round(res.item(), 4) == 3.2222  # (0*0 + 2*2 + 3*3 + 4*4) / (0 + 2 + 3 + 4)
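As a sanity check on the expected value used in both asserts (a standalone calculation, not part of the test file): index 1 carries the NaN in either `x` or `w`, so it is dropped from both tensors and the weighted mean is (0*0 + 2*2 + 3*3 + 4*4) / (0 + 2 + 3 + 4) = 29 / 9, which rounds to 3.2222.

```python
import torch

x = torch.arange(5).float()  # [0, 1, 2, 3, 4]
w = torch.arange(5).float()  # [0, 1, 2, 3, 4]
keep = torch.ones(5, dtype=torch.bool)
keep[1] = False  # index 1 is the NaN position in either x or w

expected = (x[keep] * w[keep]).sum() / w[keep].sum()
print(round(expected.item(), 4))  # 3.2222
```

Because the second `update` in the test adds the same numerator and denominator again, the running mean stays at 29 / 9, which is why both asserts compare against the same number.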
