Commit b6f6e07

Initialize aggregation metrics with default floating type (Lightning-AI#2366)
1 parent 4527aaf · commit b6f6e07
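
This commit makes the aggregation metrics (`MinMetric`, `MaxMetric`, `SumMetric`, `MeanMetric`) create their default state tensors explicitly with `torch.get_default_dtype()`, so a process-wide default of `float64` is respected. A minimal sketch of the user-visible effect, mirroring the test added below (illustrative only, not part of the diff):

```python
import torch
from torchmetrics.aggregation import MeanMetric

# Assumption: the application has raised the global default floating dtype.
torch.set_default_dtype(torch.float64)

metric = MeanMetric()              # state tensors are created as float64
metric.update(torch.randn(100))    # float64 values under the new default
result = metric.compute()
assert result.dtype == torch.float64

torch.set_default_dtype(torch.float32)  # restore the usual default
```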

File tree: 3 files changed (+24, −5 lines)

- CHANGELOG.md
- src/torchmetrics/aggregation.py
- tests/unittests/bases/test_aggregation.py

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))
 
 
+- Fixed initialize aggregation metrics with default floating type ([#2366](https://github.com/Lightning-AI/torchmetrics/pull/2366))
+
 ---
 
 ## [1.3.0] - 2024-01-10
```

src/torchmetrics/aggregation.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -157,7 +157,7 @@ def __init__(
     ) -> None:
         super().__init__(
             "max",
-            -torch.tensor(float("inf")),
+            -torch.tensor(float("inf"), dtype=torch.get_default_dtype()),
             nan_strategy,
             state_name="max_value",
             **kwargs,
@@ -262,7 +262,7 @@ def __init__(
     ) -> None:
         super().__init__(
             "min",
-            torch.tensor(float("inf")),
+            torch.tensor(float("inf"), dtype=torch.get_default_dtype()),
             nan_strategy,
             state_name="min_value",
             **kwargs,
@@ -366,7 +366,7 @@ def __init__(
     ) -> None:
         super().__init__(
             "sum",
-            torch.tensor(0.0),
+            torch.tensor(0.0, dtype=torch.get_default_dtype()),
             nan_strategy,
             state_name="sum_value",
             **kwargs,
@@ -536,12 +536,12 @@ def __init__(
     ) -> None:
         super().__init__(
             "sum",
-            torch.tensor(0.0),
+            torch.tensor(0.0, dtype=torch.get_default_dtype()),
             nan_strategy,
             state_name="mean_value",
             **kwargs,
         )
-        self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("weight", default=torch.tensor(0.0, dtype=torch.get_default_dtype()), dist_reduce_fx="sum")
 
     def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None:
         """Update state with data.
```

tests/unittests/bases/test_aggregation.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -204,3 +204,20 @@ def test_mean_metric_broadcast(nan_strategy):
     metric.update(x, w)
     res = metric.compute()
     assert round(res.item(), 4) == 3.2222  # (0*0 + 2*2 + 3*3 + 4*4) / (0 + 2 + 3 + 4)
+
+
+@pytest.mark.parametrize(
+    ("metric_class", "compare_function"),
+    [(MinMetric, torch.min), (MaxMetric, torch.max), (SumMetric, torch.sum), (MeanMetric, torch.mean)],
+)
+def test_with_default_dtype(metric_class, compare_function):
+    """Test that the metric works with a default dtype of float64."""
+    torch.set_default_dtype(torch.float64)
+    metric = metric_class()
+    values = torch.randn(10000)
+    metric.update(values)
+    result = metric.compute()
+    assert result.dtype == torch.float64
+    assert result.dtype == values.dtype
+    assert torch.allclose(result, compare_function(values), atol=1e-12)
+    torch.set_default_dtype(torch.float32)
```
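
Because `torch.set_default_dtype` mutates process-global state, the new test resets it to `torch.float32` on its final line. A more defensive variant (purely illustrative, not part of this commit) would guarantee the reset even when an assertion fails, for example via a small pytest fixture:

```python
import pytest
import torch
from torchmetrics.aggregation import SumMetric


@pytest.fixture
def float64_default():
    """Temporarily switch the global default floating dtype to float64."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    yield
    torch.set_default_dtype(previous)  # runs even if the test body fails


def test_sum_metric_with_float64_default(float64_default):
    # Hypothetical test mirroring the one added in this commit.
    metric = SumMetric()
    values = torch.randn(1000)
    metric.update(values)
    assert metric.compute().dtype == torch.float64
```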
