Skip to content

Commit

Permalink
Revert
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Oct 12, 2022
1 parent b026818 commit 826c937
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/torchmetrics/utilities/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def check_forward_full_state_property(
... ConfusionMatrix,
... init_args = {'num_classes': 3},
... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
... ) # doctest: +SKIP
... ) # doctest: +ELLIPSIS
Full state for 10 steps took: ...
Partial state for 10 steps took: ...
Full state for 100 steps took: ...
Expand All @@ -672,7 +672,7 @@ def check_forward_full_state_property(
... MyMetric,
... init_args = {'num_classes': 3},
... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
... ) # doctest: +SKIP
... )
Recommended setting `full_state_update=True`
"""

Expand Down
16 changes: 15 additions & 1 deletion tests/unittests/utilities/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import torch
from torch import tensor

from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics import MeanSquaredError, PearsonCorrCoef
from torchmetrics.utilities import check_forward_full_state_property, rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics.utilities.checks import _allclose_recursive
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot
from torchmetrics.utilities.distributed import class_reduce, reduce
Expand Down Expand Up @@ -129,6 +130,19 @@ def test_bincount():
assert torch.allclose(res1, res3)


@pytest.mark.parametrize("metric_class, expected", [(MeanSquaredError, False), (PearsonCorrCoef, True)])
def test_check_full_state_update_fn(capsys, metric_class, expected):
"""Test that the check function works as it should."""
check_forward_full_state_property(
metric_class=metric_class,
input_args=dict(preds=torch.randn(1000), target=torch.randn(1000)),
num_update_to_compare=[10000],
reps=5,
)
captured = capsys.readouterr()
assert f"Recommended setting `full_state_update={expected}`" in captured.out


@pytest.mark.parametrize(
"input, expected",
[
Expand Down

0 comments on commit 826c937

Please sign in to comment.