Docs on memory management (#2006)
Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
4 people authored Aug 22, 2023
1 parent 465481c commit bc8e016
Showing 4 changed files with 163 additions and 102 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `CLIPImageQualityAssessment` to multimodal package ([#1931](https://github.com/Lightning-AI/torchmetrics/pull/1931))


- Added new property `metric_state` to all metrics for users to investigate currently stored tensors in memory ([#2006](https://github.com/Lightning-AI/torchmetrics/pull/2006))

### Changed

-
56 changes: 55 additions & 1 deletion docs/source/pages/overview.rst
@@ -130,8 +130,62 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.

You can always check which device the metric is located on using the `.device` property.

*****************************
Metrics and memory management
*****************************

As stated before, metrics have states and those states take up a certain amount of memory depending on the metric.
In general, metrics can be divided into two categories when it comes to memory management:

* Metrics with tensor states: These metrics only have states that are instances of :class:`~torch.Tensor`. When these
  kinds of metrics are updated, the values of those tensors are updated in place. Importantly, the size of the tensors
  is **constant**, meaning that regardless of how much data is passed to the metric, its memory footprint will not change.

* Metrics with list states: These metrics have at least one state that is a list, which gets tensors appended to it as
  the metric is updated. Importantly, the size of the list is therefore **not constant** and will grow as the metric is
  updated. How fast it grows depends on the particular metric (some metrics only need to store a single value per
  sample, others much more).

You can always inspect the current metric state by accessing the `.metric_state` property and checking whether any of
the states are lists.

.. testcode::

import torch
from torchmetrics.regression import SpearmanCorrCoef

gen = torch.manual_seed(42)
metric = SpearmanCorrCoef()
metric(torch.rand(2,), torch.rand(2,))
print(metric.metric_state)
metric(torch.rand(2,), torch.rand(2,))
print(metric.metric_state)

.. testoutput::
:options: +NORMALIZE_WHITESPACE

{'preds': [tensor([0.8823, 0.9150])], 'target': [tensor([0.3829, 0.9593])]}
{'preds': [tensor([0.8823, 0.9150]), tensor([0.3904, 0.6009])], 'target': [tensor([0.3829, 0.9593]), tensor([0.2566, 0.7936])]}
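
For comparison, a metric with only tensor states keeps a constant-size state no matter how much data it has seen.
Below is a minimal sketch (not part of the doctests above) using :class:`~torchmetrics.regression.MeanSquaredError`;
the state names shown in the comments are indicative and may differ between versions:

.. code-block:: python

    import torch
    from torchmetrics.regression import MeanSquaredError

    metric = MeanSquaredError()
    metric(torch.rand(10), torch.rand(10))
    print(metric.metric_state)  # two scalar tensors, e.g. {'sum_squared_error': tensor(...), 'total': tensor(10)}
    metric(torch.rand(1000), torch.rand(1000))
    print(metric.metric_state)  # still two scalar tensors; the memory footprint did not grow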

In general, we have a few recommendations for memory management:

* When done with a metric, we always recommend calling the `reset` method. The reason is that the Python garbage
  collector can struggle to fully clean up the metric states if this is not done. In the worst case, this can lead to
  a memory leak if multiple instances of the same metric are created for different purposes in the same script.

* It is better to reuse the same instance of a metric than to initialize a new one. Calling the `reset` method returns
  the metric to its initial state, so the same instance can be reused (see the sketch after this list). However, we
  still highly recommend using **different** instances for training, validation and testing.

* If only batch-level results are needed (i.e. no aggregation across batches), or if your dataset is small enough to be
  evaluated in a single iteration, we recommend using the functional API instead, as it does not keep any internal
  state and memory is therefore freed after each call.
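
As a sketch of the first two recommendations, the same metric instance can be reused across epochs by calling `reset`
once the aggregated value has been computed (the loop below is purely illustrative):

.. code-block:: python

    import torch
    from torchmetrics.classification import BinaryAccuracy

    train_acc = BinaryAccuracy()  # one instance per stage (train/val/test)
    for epoch in range(2):
        for _ in range(10):  # stand-in for iterating a dataloader
            preds, target = torch.rand(8), torch.randint(2, (8,))
            train_acc.update(preds, target)
        print(f"epoch {epoch}: accuracy={train_acc.compute()}")
        train_acc.reset()  # clear the states before the next epoch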

See :ref:`Metric kwargs` for advanced settings that control the memory footprint of metrics.
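
For example, the growing list states of a metric can be kept in CPU memory with the `compute_on_cpu` argument described
there; a minimal sketch, assuming the metric itself runs on a GPU:

.. code-block:: python

    from torchmetrics.regression import SpearmanCorrCoef

    # list states are moved to CPU after each update, freeing GPU memory
    metric = SpearmanCorrCoef(compute_on_cpu=True).to("cuda")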

***********************************************
Metrics in Distributed Data Parallel (DDP) mode
***********************************************

When using metrics in `Distributed Data Parallel (DDP) <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_
mode, one should be aware that DDP will add additional samples to your dataset if the size of your dataset is
6 changes: 6 additions & 0 deletions src/torchmetrics/metric.py
Expand Up @@ -86,6 +86,7 @@ class Metric(Module, ABC):
"plot_lower_bound",
"plot_upper_bound",
"plot_legend_name",
"metric_state",
]
is_differentiable: Optional[bool] = None
higher_is_better: Optional[bool] = None
@@ -180,6 +181,11 @@ def update_count(self) -> int:
"""Get the number of times `update` and/or `forward` has been called since initialization or last `reset`."""
return self._update_count

@property
def metric_state(self) -> Dict[str, Union[List[Tensor], Tensor]]:
"""Get the current state of the metric."""
return {attr: getattr(self, attr) for attr in self._defaults}

def add_state(
self,
name: str,
201 changes: 100 additions & 101 deletions tests/unittests/bases/test_metric.py
@@ -23,7 +23,7 @@
import pytest
import torch
from torch import Tensor, tensor
from torch.nn import Module
from torch.nn import Module, Parameter
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.regression import PearsonCorrCoef

@@ -65,44 +65,46 @@ def test_inherit():

def test_add_state():
"""Test that add state method works as expected."""
a = DummyMetric()
metric = DummyMetric()

a.add_state("a", tensor(0), "sum")
assert a._reductions["a"](tensor([1, 1])) == 2
metric.add_state("a", tensor(0), "sum")
assert metric._reductions["a"](tensor([1, 1])) == 2

a.add_state("b", tensor(0), "mean")
assert np.allclose(a._reductions["b"](tensor([1.0, 2.0])).numpy(), 1.5)
metric.add_state("b", tensor(0), "mean")
assert np.allclose(metric._reductions["b"](tensor([1.0, 2.0])).numpy(), 1.5)

a.add_state("c", tensor(0), "cat")
assert a._reductions["c"]([tensor([1]), tensor([1])]).shape == (2,)
metric.add_state("c", tensor(0), "cat")
assert metric._reductions["c"]([tensor([1]), tensor([1])]).shape == (2,)

with pytest.raises(ValueError, match="`dist_reduce_fx` must be callable or one of .*"):
a.add_state("d1", tensor(0), "xyz")
metric.add_state("d1", tensor(0), "xyz")

with pytest.raises(ValueError, match="`dist_reduce_fx` must be callable or one of .*"):
a.add_state("d2", tensor(0), 42)
metric.add_state("d2", tensor(0), 42)

with pytest.raises(ValueError, match="state variable must be a tensor or any empty list .*"):
a.add_state("d3", [tensor(0)], "sum")
metric.add_state("d3", [tensor(0)], "sum")

with pytest.raises(ValueError, match="state variable must be a tensor or any empty list .*"):
a.add_state("d4", 42, "sum")
metric.add_state("d4", 42, "sum")

def custom_fx(_):
return -1

a.add_state("e", tensor(0), custom_fx)
assert a._reductions["e"](tensor([1, 1])) == -1
metric.add_state("e", tensor(0), custom_fx)
assert metric._reductions["e"](tensor([1, 1])) == -1


def test_add_state_persistent():
"""Test that metric states are not added to the normal state dict."""
a = DummyMetric()
metric = DummyMetric()

a.add_state("a", tensor(0), "sum", persistent=True)
assert "a" in a.state_dict()
metric.add_state("a", tensor(0), "sum", persistent=True)
assert "a" in metric.state_dict()

a.add_state("b", tensor(0), "sum", persistent=False)
metric.add_state("b", tensor(0), "sum", persistent=False)
assert "a" in metric.metric_state
assert "b" in metric.metric_state


def test_reset():
@@ -114,29 +116,31 @@ class A(DummyMetric):
class B(DummyListMetric):
pass

a = A()
assert a.x == 0
a.x = tensor(5)
a.reset()
assert a.x == 0
metric = A()
assert metric.x == 0
metric.x = tensor(5)
metric.reset()
assert metric.x == 0

b = B()
assert isinstance(b.x, list)
assert len(b.x) == 0
b.x = tensor(5)
b.reset()
assert isinstance(b.x, list)
assert len(b.x) == 0
metric = B()
assert isinstance(metric.x, list)
assert len(metric.x) == 0
metric.x = tensor(5)
metric.reset()
assert isinstance(metric.x, list)
assert len(metric.x) == 0


def test_reset_compute():
"""Test that `reset`+`compute` methods works as expected."""
a = DummyMetricSum()
assert a.x == 0
a.update(tensor(5))
assert a.compute() == 5
a.reset()
assert a.compute() == 0
metric = DummyMetricSum()
assert metric.metric_state == {"x": tensor(0)}
metric.update(tensor(5))
assert metric.metric_state == {"x": tensor(5)}
assert metric.compute() == 5
metric.reset()
assert metric.metric_state == {"x": tensor(0)}
assert metric.compute() == 0


def test_update():
Expand All @@ -147,92 +151,74 @@ def update(self, x):
self.x += x

a = A()
assert a.x == 0
assert a.metric_state == {"x": tensor(0)}
assert a._computed is None
a.update(1)
assert a._computed is None
assert a.x == 1
assert a.metric_state == {"x": tensor(1)}
a.update(2)
assert a.x == 3
assert a.metric_state == {"x": tensor(3)}
assert a._computed is None


@pytest.mark.parametrize("compute_with_cache", [True, False])
def test_compute(compute_with_cache):
"""Test that `compute` method works as expected."""

class A(DummyMetric):
def update(self, x):
self.x += x

def compute(self):
return self.x

a = A(compute_with_cache=compute_with_cache)
assert a.compute() == 0
assert a.x == 0
a.update(1)
assert a._computed is None
assert a.compute() == 1
assert a._computed == 1 if compute_with_cache else a._computed is None
a.update(2)
assert a._computed is None
assert a.compute() == 3
assert a._computed == 3 if compute_with_cache else a._computed is None
metric = DummyMetricSum(compute_with_cache=compute_with_cache)
assert metric.compute() == 0
assert metric.metric_state == {"x": tensor(0)}
metric.update(1)
assert metric._computed is None
assert metric.compute() == 1
assert metric._computed == 1 if compute_with_cache else metric._computed is None
assert metric.metric_state == {"x": tensor(1)}
metric.update(2)
assert metric._computed is None
assert metric.compute() == 3
assert metric._computed == 3 if compute_with_cache else metric._computed is None
assert metric.metric_state == {"x": tensor(3)}

# called without update, should return cached value
a._computed = 5
assert a.compute() == 5
metric._computed = 5
assert metric.compute() == 5
assert metric.metric_state == {"x": tensor(3)}


def test_hash():
"""Test that hashes for different metrics are different, even if states are the same."""

class A(DummyMetric):
pass

class B(DummyListMetric):
pass

a1 = A()
a2 = A()
assert hash(a1) != hash(a2)

b1 = B()
b2 = B()
assert hash(b1) != hash(b2) # different ids
assert isinstance(b1.x, list)
assert len(b1.x) == 0
b1.x.append(tensor(5))
assert isinstance(hash(b1), int) # <- check that nothing crashes
assert isinstance(b1.x, list)
assert len(b1.x) == 1
b2.x.append(tensor(5))
metric_1 = DummyMetric()
metric_2 = DummyMetric()
assert hash(metric_1) != hash(metric_2)

metric_1 = DummyListMetric()
metric_2 = DummyListMetric()
assert hash(metric_1) != hash(metric_2) # different ids
assert isinstance(metric_1.x, list)
assert len(metric_1.x) == 0
metric_1.x.append(tensor(5))
assert isinstance(hash(metric_1), int) # <- check that nothing crashes
assert isinstance(metric_1.x, list)
assert len(metric_1.x) == 1
metric_2.x.append(tensor(5))
# Sanity:
assert isinstance(b2.x, list)
assert len(b2.x) == 1
assert isinstance(metric_2.x, list)
assert len(metric_2.x) == 1
# Now that they have tensor contents, they should have different hashes:
assert hash(b1) != hash(b2)
assert hash(metric_1) != hash(metric_2)


def test_forward():
"""Test that `forward` method works as expected."""
metric = DummyMetricSum()
assert metric(5) == 5
assert metric._forward_cache == 5
assert metric.metric_state == {"x": tensor(5)}

class A(DummyMetric):
def update(self, x):
self.x += x

def compute(self):
return self.x

a = A()
assert a(5) == 5
assert a._forward_cache == 5

assert a(8) == 8
assert a._forward_cache == 8
assert metric(8) == 8
assert metric._forward_cache == 8
assert metric.metric_state == {"x": tensor(13)}

assert a.compute() == 13
assert metric.compute() == 13


def test_pickle(tmpdir):
@@ -275,6 +261,19 @@ def test_load_state_dict(tmpdir):
assert metric.compute() == 5


def test_check_register_not_in_metric_state():
"""Check that calling `register_buffer` or `register_parameter` does not get added to metric state."""

class TempDummyMetric(DummyMetricSum):
def __init__(self) -> None:
super().__init__()
self.register_buffer("buffer", tensor(0, dtype=torch.float))
self.register_parameter("parameter", Parameter(tensor(0, dtype=torch.float)))

metric = TempDummyMetric()
assert metric.metric_state == {"x": tensor(0)}


def test_child_metric_state_dict():
"""Test that child metric states will be added to parent state dict."""

@@ -356,10 +355,10 @@ def test_warning_on_compute_before_update():
assert val == 2.0


def test_metric_scripts():
@pytest.mark.parametrize("metric_class", [DummyMetric, DummyMetricSum, DummyMetricMultiOutput, DummyListMetric])
def test_metric_scripts(metric_class):
"""Test that metrics are scriptable."""
torch.jit.script(DummyMetric())
torch.jit.script(DummyMetricSum())
torch.jit.script(metric_class())


def test_metric_forward_cache_reset():
