Commit c35d0f0

Merge branch 'master' into bugfix/image_metrics

2 parents: ede6f59 + 962f82d

File tree: 2 files changed (+38, -14 lines)

2 files changed

+38
-14
lines changed

docs/source/pages/implement.rst

Lines changed: 36 additions & 14 deletions
@@ -2,13 +2,15 @@
 
 .. testsetup:: *
 
-    from typing import Optional
+    from typing import Optional, Sequence, Union
+    from torch import Tensor
 
 *********************
 Implementing a Metric
 *********************
 
-To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following methods:
+To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following
+methods:
 
 - ``__init__()``: Each state variable should be called using ``self.add_state(...)``.
 - ``update()``: Any code needed to update the state given any inputs to the metric.
@@ -32,7 +34,7 @@ Example implementation:
             self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
             self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
 
-        def update(self, preds: torch.Tensor, target: torch.Tensor):
+        def update(self, preds: Tensor, target: Tensor):
             preds, target = self._input_format(preds, target)
             assert preds.shape == target.shape
 
@@ -48,21 +50,41 @@ Additionally you may want to set the class properties: `is_differentiable`, `higher_is_better`
 
 .. testcode::
 
-    from torchmetrics import Metric
+    from torchmetrics import Metric
+
+    class MyMetric(Metric):
+        # Set to True if the metric is differentiable else set to False
+        is_differentiable: Optional[bool] = None
+
+        # Set to True if the metric reaches its optimal value when the metric is maximized.
+        # Set to False if it is reached when the metric is minimized.
+        higher_is_better: Optional[bool] = True
+
+        # Set to True if the metric during 'update' requires access to the global metric
+        # state for its calculations. If not, setting this to False indicates that all
+        # batch states are independent and we will optimize the runtime of 'forward'
+        full_state_update: bool = True
 
-    class MyMetric(Metric):
-        # Set to True if the metric is differentiable else set to False
-        is_differentiable: Optional[bool] = None
+Finally, from torchmetrics v1.0.0 onwards, we also support plotting of metrics through the `.plot` method. By default
+this method will raise `NotImplementedError`, but it can be implemented by the user to provide a custom plot for the
+metric. For any metric that returns a simple scalar tensor, or a dict of scalar tensors, the internal `._plot` method
+can be used; it provides the common plotting functionality for most metrics in torchmetrics.
+
+.. testcode::
+
+    from torchmetrics import Metric
+    from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
 
-        # Set to True if the metric reaches it optimal value when the metric is maximized.
-        # Set to False if it when the metric is minimized.
-        higher_is_better: Optional[bool] = True
+    class MyMetric(Metric):
+        ...
 
-        # Set to True if the metric during 'update' requires access to the global metric
-        # state for its calculations. If not, setting this to False indicates that all
-        # batch states are independent and we will optimize the runtime of 'forward'
-        full_state_update: bool = True
+        def plot(
+            self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+        ) -> _PLOT_OUT_TYPE:
+            return self._plot(val, ax)
 
+If the metric returns a more complex output, a custom implementation of the `plot` method is required. For more details
+on the plotting API, see this :ref:`page <plotting>`.
 
 Internal implementation details
 -------------------------------
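The workflow the diff documents (register states in `__init__`, accumulate them in `update`, derive the result in `compute`) can be illustrated without torch at all. The sketch below is a hypothetical, torch-free analogue of the accuracy example, not the torchmetrics API itself: plain integers stand in for the tensor states registered with `add_state(..., dist_reduce_fx="sum")`.

```python
class SimpleAccuracy:
    """Torch-free sketch of the Metric pattern: states + update() + compute()."""

    def __init__(self):
        # Analogue of add_state("correct"/"total", default=0, dist_reduce_fx="sum"):
        # each state starts at its default and is summed across batches.
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # Accumulate batch statistics into the state variables.
        assert len(preds) == len(target)
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        # Finalize: derive the metric value from the accumulated state.
        return self.correct / self.total if self.total else 0.0


metric = SimpleAccuracy()
metric.update([1, 0, 1, 1], [1, 1, 1, 0])  # 2 correct out of 4
metric.update([0, 0], [0, 1])              # 1 correct out of 2
print(metric.compute())                    # 3 correct out of 6 -> 0.5
```

In the real torchmetrics base class, `add_state` additionally handles device placement and the distributed reduction that this sketch fakes with ordinary integer addition.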

docs/source/pages/plotting.rst

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+.. _plotting:
+
 .. testsetup:: *
 
     import torch
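The `plot`-delegation pattern added in implement.rst (a thin public method forwarding to a shared helper) can also be sketched without matplotlib. Everything below is hypothetical: `_shared_plot` stands in for torchmetrics' internal `._plot` helper, and a string stands in for the figure-and-axes pair the real API returns.

```python
from collections.abc import Sequence
from typing import Optional, Union

Value = Union[float, Sequence]


class MetricSketch:
    """Hypothetical metric exposing the thin plot() wrapper pattern."""

    def compute(self) -> float:
        # Placeholder scalar, standing in for a computed metric value.
        return 0.5

    def _shared_plot(self, val: Optional[Value] = None) -> str:
        # Stand-in for the shared helper: fall back to the computed
        # value when no explicit value is supplied, then "render" it.
        if val is None:
            val = self.compute()
        vals = list(val) if isinstance(val, Sequence) else [val]
        return f"plot of {len(vals)} value(s)"

    def plot(self, val: Optional[Value] = None) -> str:
        # The public method is just a thin forwarder, mirroring
        # the `return self._plot(val, ax)` shape shown in the docs.
        return self._shared_plot(val)


m = MetricSketch()
print(m.plot())            # no value given: falls back to compute()
print(m.plot([0.1, 0.9]))  # a sequence of values plots each one
```

Keeping the public `plot` this thin is the point of the pattern: metrics with scalar outputs get plotting for free, while metrics with richer outputs override `plot` entirely.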
