From bc8e016bfcac45ccd58caf4d4b4ba2d67fdacdd2 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 22 Aug 2023 10:53:07 +0200
Subject: [PATCH] Docs on memory management (#2006)

Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Jirka
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 CHANGELOG.md                         |   2 +
 docs/source/pages/overview.rst       |  56 +++++++-
 src/torchmetrics/metric.py           |   6 +
 tests/unittests/bases/test_metric.py | 201 +++++++++++++--------------
 4 files changed, 163 insertions(+), 102 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c7770a90667..1d4c262ac67 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `CLIPImageQualityAssessment` to multimodal package ([#1931](https://github.com/Lightning-AI/torchmetrics/pull/1931))
 
+- Added new property `metric_state` to all metrics for users to investigate currently stored tensors in memory ([#2006](https://github.com/Lightning-AI/torchmetrics/pull/2006))
+
 ### Changed
 
 -
diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst
index a3c32ae34e4..6fa6aeb256b 100644
--- a/docs/source/pages/overview.rst
+++ b/docs/source/pages/overview.rst
@@ -130,8 +130,62 @@ the native `MetricCollection`_ module can also be used to wrap multiple
 metrics. You can always check which device the metric is located on using the `.device` property.
 
+*****************************
+Metrics and memory management
+*****************************
+
+As stated before, metrics have states and those states take up a certain amount of memory depending on the metric.
+In general, metrics can be divided into two categories when we talk about memory management:
+
+* Metrics with tensor states: These metrics only have states that are instances of :class:`~torch.Tensor`. When these
+  kinds of metrics are updated, the values of those tensors are updated. Importantly, the size of the tensors is
+  **constant**, meaning that regardless of how much data is passed to the metric, its memory footprint will not change.
+
+* Metrics with list states: These metrics have at least one state that is a list, to which tensors are appended as the
+  metric is updated. Importantly, the size of the list is therefore **not constant** and will grow as the metric is
+  updated. How fast it grows depends on the particular metric (some metrics only need to store a single value per
+  sample, others much more).
+
+You can always check the current metric state by accessing the `.metric_state` property and checking whether any of
+the states are lists.
+
+.. testcode::
+
+    import torch
+    from torchmetrics.regression import SpearmanCorrCoef
+
+    gen = torch.manual_seed(42)
+    metric = SpearmanCorrCoef()
+    metric(torch.rand(2,), torch.rand(2,))
+    print(metric.metric_state)
+    metric(torch.rand(2,), torch.rand(2,))
+    print(metric.metric_state)
+
+.. testoutput::
+    :options: +NORMALIZE_WHITESPACE
+
+    {'preds': [tensor([0.8823, 0.9150])], 'target': [tensor([0.3829, 0.9593])]}
+    {'preds': [tensor([0.8823, 0.9150]), tensor([0.3904, 0.6009])], 'target': [tensor([0.3829, 0.9593]), tensor([0.2566, 0.7936])]}
+
+In general, we have a few recommendations for memory management:
+
+* When you are done with a metric, we always recommend calling the `reset` method. The reason is that the Python
+  garbage collector can struggle to fully clean up the metric states otherwise. In the worst case, this can lead to a
+  memory leak if multiple instances of the same metric are created for different purposes in the same script.
+
+* It is better to reuse the same instance of a metric than to initialize a new one. Calling the `reset` method returns
+  the metric to its initial state, so the same instance can be reused (see the sketch after this list). However, we
+  still highly recommend using **different** instances for training, validation and testing.
+
+* If you only need results on the batch level, e.g. no aggregation, or if you have a small dataset that fits into a
+  single evaluation iteration, we recommend using the functional API instead, as it does not keep an internal state
+  and memory is therefore freed after each call.
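+
+For illustration, the following is a minimal sketch of how a single metric instance could be reused across evaluation
+rounds; the metric choice, loop structure and random data here are only an example:
+
+.. testcode::
+
+    import torch
+    from torchmetrics.classification import BinaryAccuracy
+
+    metric = BinaryAccuracy()
+    for _ in range(2):  # e.g. two evaluation rounds
+        for _ in range(4):  # e.g. four batches per round
+            # update with (hypothetical) predictions and targets for the current batch
+            metric.update(torch.rand(10), torch.randint(2, (10,)))
+        result = metric.compute()  # aggregate over the current round
+        metric.reset()  # return the metric to its initial state before the next round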
+
+See :ref:`Metric kwargs` for advanced settings that control the memory footprint of metrics.
+
+***********************************************
 Metrics in Distributed Data Parallel (DDP) mode
-===============================================
+***********************************************
 
 When using metrics in `Distributed Data Parallel (DDP) `_ mode,
 one should be aware that DDP will add additional samples to your dataset if the size of your dataset is
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 8819d9618f2..ddf386ebe6c 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -86,6 +86,7 @@ class Metric(Module, ABC):
         "plot_lower_bound",
         "plot_upper_bound",
         "plot_legend_name",
+        "metric_state",
     ]
     is_differentiable: Optional[bool] = None
     higher_is_better: Optional[bool] = None
@@ -180,6 +181,11 @@ def update_count(self) -> int:
         """Get the number of times `update` and/or `forward` has been called since initialization or last `reset`."""
         return self._update_count
 
+    @property
+    def metric_state(self) -> Dict[str, Union[List[Tensor], Tensor]]:
+        """Get the current state of the metric."""
+        return {attr: getattr(self, attr) for attr in self._defaults}
+
     def add_state(
         self,
         name: str,
diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py
index 62eddb93141..5e6a1c2e979 100644
--- a/tests/unittests/bases/test_metric.py
+++ b/tests/unittests/bases/test_metric.py
@@ -23,7 +23,7 @@
 import pytest
 import torch
 from torch import Tensor, tensor
-from torch.nn import Module
+from torch.nn import Module, Parameter
 
 from torchmetrics.classification import BinaryAccuracy
 from torchmetrics.regression import PearsonCorrCoef
@@ -65,44 +65,46 @@ def test_inherit():
 
 def test_add_state():
     """Test that add state method works as expected."""
-    a = DummyMetric()
+    metric = DummyMetric()
 
-    a.add_state("a", tensor(0), "sum")
-    assert a._reductions["a"](tensor([1, 1])) == 2
+    metric.add_state("a", tensor(0), "sum")
+    assert metric._reductions["a"](tensor([1, 1])) == 2
 
-    a.add_state("b", tensor(0), "mean")
-    assert np.allclose(a._reductions["b"](tensor([1.0, 2.0])).numpy(), 1.5)
+    metric.add_state("b", tensor(0), "mean")
+    assert np.allclose(metric._reductions["b"](tensor([1.0, 2.0])).numpy(), 1.5)
 
-    a.add_state("c", tensor(0), "cat")
-    assert a._reductions["c"]([tensor([1]), tensor([1])]).shape == (2,)
+    metric.add_state("c", tensor(0), "cat")
+    assert metric._reductions["c"]([tensor([1]), tensor([1])]).shape == (2,)
 
     with pytest.raises(ValueError, match="`dist_reduce_fx` must be callable or one of .*"):
-        a.add_state("d1", tensor(0), "xyz")
+        metric.add_state("d1", tensor(0), "xyz")
 
     with pytest.raises(ValueError, match="`dist_reduce_fx` must be callable or one 
of .*"): - a.add_state("d2", tensor(0), 42) + metric.add_state("d2", tensor(0), 42) with pytest.raises(ValueError, match="state variable must be a tensor or any empty list .*"): - a.add_state("d3", [tensor(0)], "sum") + metric.add_state("d3", [tensor(0)], "sum") with pytest.raises(ValueError, match="state variable must be a tensor or any empty list .*"): - a.add_state("d4", 42, "sum") + metric.add_state("d4", 42, "sum") def custom_fx(_): return -1 - a.add_state("e", tensor(0), custom_fx) - assert a._reductions["e"](tensor([1, 1])) == -1 + metric.add_state("e", tensor(0), custom_fx) + assert metric._reductions["e"](tensor([1, 1])) == -1 def test_add_state_persistent(): """Test that metric states are not added to the normal state dict.""" - a = DummyMetric() + metric = DummyMetric() - a.add_state("a", tensor(0), "sum", persistent=True) - assert "a" in a.state_dict() + metric.add_state("a", tensor(0), "sum", persistent=True) + assert "a" in metric.state_dict() - a.add_state("b", tensor(0), "sum", persistent=False) + metric.add_state("b", tensor(0), "sum", persistent=False) + assert "a" in metric.metric_state + assert "b" in metric.metric_state def test_reset(): @@ -114,29 +116,31 @@ class A(DummyMetric): class B(DummyListMetric): pass - a = A() - assert a.x == 0 - a.x = tensor(5) - a.reset() - assert a.x == 0 + metric = A() + assert metric.x == 0 + metric.x = tensor(5) + metric.reset() + assert metric.x == 0 - b = B() - assert isinstance(b.x, list) - assert len(b.x) == 0 - b.x = tensor(5) - b.reset() - assert isinstance(b.x, list) - assert len(b.x) == 0 + metric = B() + assert isinstance(metric.x, list) + assert len(metric.x) == 0 + metric.x = tensor(5) + metric.reset() + assert isinstance(metric.x, list) + assert len(metric.x) == 0 def test_reset_compute(): """Test that `reset`+`compute` methods works as expected.""" - a = DummyMetricSum() - assert a.x == 0 - a.update(tensor(5)) - assert a.compute() == 5 - a.reset() - assert a.compute() == 0 + metric = DummyMetricSum() + assert metric.metric_state == {"x": tensor(0)} + metric.update(tensor(5)) + assert metric.metric_state == {"x": tensor(5)} + assert metric.compute() == 5 + metric.reset() + assert metric.metric_state == {"x": tensor(0)} + assert metric.compute() == 0 def test_update(): @@ -147,92 +151,74 @@ def update(self, x): self.x += x a = A() - assert a.x == 0 + assert a.metric_state == {"x": tensor(0)} assert a._computed is None a.update(1) assert a._computed is None - assert a.x == 1 + assert a.metric_state == {"x": tensor(1)} a.update(2) - assert a.x == 3 + assert a.metric_state == {"x": tensor(3)} assert a._computed is None @pytest.mark.parametrize("compute_with_cache", [True, False]) def test_compute(compute_with_cache): """Test that `compute` method works as expected.""" - - class A(DummyMetric): - def update(self, x): - self.x += x - - def compute(self): - return self.x - - a = A(compute_with_cache=compute_with_cache) - assert a.compute() == 0 - assert a.x == 0 - a.update(1) - assert a._computed is None - assert a.compute() == 1 - assert a._computed == 1 if compute_with_cache else a._computed is None - a.update(2) - assert a._computed is None - assert a.compute() == 3 - assert a._computed == 3 if compute_with_cache else a._computed is None + metric = DummyMetricSum(compute_with_cache=compute_with_cache) + assert metric.compute() == 0 + assert metric.metric_state == {"x": tensor(0)} + metric.update(1) + assert metric._computed is None + assert metric.compute() == 1 + assert metric._computed == 1 if compute_with_cache else 
metric._computed is None + assert metric.metric_state == {"x": tensor(1)} + metric.update(2) + assert metric._computed is None + assert metric.compute() == 3 + assert metric._computed == 3 if compute_with_cache else metric._computed is None + assert metric.metric_state == {"x": tensor(3)} # called without update, should return cached value - a._computed = 5 - assert a.compute() == 5 + metric._computed = 5 + assert metric.compute() == 5 + assert metric.metric_state == {"x": tensor(3)} def test_hash(): """Test that hashes for different metrics are different, even if states are the same.""" - - class A(DummyMetric): - pass - - class B(DummyListMetric): - pass - - a1 = A() - a2 = A() - assert hash(a1) != hash(a2) - - b1 = B() - b2 = B() - assert hash(b1) != hash(b2) # different ids - assert isinstance(b1.x, list) - assert len(b1.x) == 0 - b1.x.append(tensor(5)) - assert isinstance(hash(b1), int) # <- check that nothing crashes - assert isinstance(b1.x, list) - assert len(b1.x) == 1 - b2.x.append(tensor(5)) + metric_1 = DummyMetric() + metric_2 = DummyMetric() + assert hash(metric_1) != hash(metric_2) + + metric_1 = DummyListMetric() + metric_2 = DummyListMetric() + assert hash(metric_1) != hash(metric_2) # different ids + assert isinstance(metric_1.x, list) + assert len(metric_1.x) == 0 + metric_1.x.append(tensor(5)) + assert isinstance(hash(metric_1), int) # <- check that nothing crashes + assert isinstance(metric_1.x, list) + assert len(metric_1.x) == 1 + metric_2.x.append(tensor(5)) # Sanity: - assert isinstance(b2.x, list) - assert len(b2.x) == 1 + assert isinstance(metric_2.x, list) + assert len(metric_2.x) == 1 # Now that they have tensor contents, they should have different hashes: - assert hash(b1) != hash(b2) + assert hash(metric_1) != hash(metric_2) def test_forward(): """Test that `forward` method works as expected.""" + metric = DummyMetricSum() + assert metric(5) == 5 + assert metric._forward_cache == 5 + assert metric.metric_state == {"x": tensor(5)} - class A(DummyMetric): - def update(self, x): - self.x += x - - def compute(self): - return self.x - - a = A() - assert a(5) == 5 - assert a._forward_cache == 5 - - assert a(8) == 8 - assert a._forward_cache == 8 + assert metric(8) == 8 + assert metric._forward_cache == 8 + assert metric.metric_state == {"x": tensor(13)} - assert a.compute() == 13 + assert metric.compute() == 13 def test_pickle(tmpdir): @@ -275,6 +261,19 @@ def test_load_state_dict(tmpdir): assert metric.compute() == 5 +def test_check_register_not_in_metric_state(): + """Check that calling `register_buffer` or `register_parameter` does not get added to metric state.""" + + class TempDummyMetric(DummyMetricSum): + def __init__(self) -> None: + super().__init__() + self.register_buffer("buffer", tensor(0, dtype=torch.float)) + self.register_parameter("parameter", Parameter(tensor(0, dtype=torch.float))) + + metric = TempDummyMetric() + assert metric.metric_state == {"x": tensor(0)} + + def test_child_metric_state_dict(): """Test that child metric states will be added to parent state dict.""" @@ -356,10 +355,10 @@ def test_warning_on_compute_before_update(): assert val == 2.0 -def test_metric_scripts(): +@pytest.mark.parametrize("metric_class", [DummyMetric, DummyMetricSum, DummyMetricMultiOutput, DummyListMetric]) +def test_metric_scripts(metric_class): """Test that metrics are scriptable.""" - torch.jit.script(DummyMetric()) - torch.jit.script(DummyMetricSum()) + torch.jit.script(metric_class()) def test_metric_forward_cache_reset():