.. testsetup:: *

    from typing import Optional, Sequence, Union
    from torch import Tensor

*********************
Implementing a Metric
*********************

To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following
methods:

- ``__init__()``: Each state variable should be registered using ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
- ``compute()``: Computes a final value from the state of the metric.

Example implementation:

.. testcode::

    import torch
    from torchmetrics import Metric


    class MyAccuracy(Metric):
        def __init__(self):
            super().__init__()
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
33
35
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
34
36
35
- def update(self, preds: torch. Tensor, target: torch. Tensor):
37
+ def update(self, preds: Tensor, target: Tensor):
36
38
preds, target = self._input_format(preds, target)
37
39
assert preds.shape == target.shape
38
40
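
As a quick sketch of how the finished metric is used (not part of the original example, and assuming the
``_input_format`` helper referenced in ``update`` is defined, e.g. as a no-op for inputs that already share a shape):

.. testcode::

    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    my_accuracy = MyAccuracy()
    # calling the metric object runs forward(), which both updates the state and computes the value
    acc = my_accuracy(preds, target)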

Additionally, you may want to set the class properties: `is_differentiable`, `higher_is_better` and
`full_state_update`.

.. testcode::

    from torchmetrics import Metric


    class MyMetric(Metric):
        # Set to True if the metric is differentiable, else set to False
        is_differentiable: Optional[bool] = None

        # Set to True if the metric reaches its optimal value when maximized,
        # and to False if it is optimal when minimized.
        higher_is_better: Optional[bool] = True

        # Set to True if the metric during 'update' requires access to the global metric
        # state for its calculations. If not, setting this to False indicates that all
        # batch states are independent and we will optimize the runtime of 'forward'
        full_state_update: bool = True
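
As an illustration (a hypothetical sketch, not from the original docs), the accuracy metric from the earlier example
might declare these properties as follows; the values follow from accuracy using a hard, non-differentiable comparison,
being maximized at 1.0, and accumulating independent per-batch sums:

.. testcode::

    class MyAccuracy(Metric):
        # accuracy is computed with a hard comparison, so it is not differentiable
        is_differentiable: Optional[bool] = False

        # a higher accuracy is a better accuracy
        higher_is_better: Optional[bool] = True

        # each batch update only adds to running sums, so no global state is needed
        full_state_update: bool = False

        ...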

Finally, from torchmetrics v1.0.0 onwards, we also support plotting of metrics through the `.plot` method. By default
this method will raise `NotImplementedError` but can be implemented by the user to provide a custom plot for the metric.
For any metric that returns a simple scalar tensor, or a dict of scalar tensors, the internal `._plot` method can be
used; it provides the common plotting functionality for most metrics in torchmetrics.

.. testcode::

    from torchmetrics import Metric
    from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE


    class MyMetric(Metric):
        ...

        def plot(
            self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
        ) -> _PLOT_OUT_TYPE:
            return self._plot(val, ax)
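
A usage sketch (again an assumption, not from the original docs, and assuming ``MyMetric`` also implements ``update``
and ``compute`` so it can be instantiated): since `._plot` returns a matplotlib figure and axis, the result can be
shown or saved directly once the metric holds some state:

.. testcode::

    metric = MyMetric()
    # ... call metric.update(...) one or more times ...
    fig, ax = metric.plot()  # plots the current result of compute()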

If the metric returns a more complex output, a custom implementation of the `plot` method is required. For more details
on the plotting API, see this :ref:`page <plotting>`.

Internal implementation details
-------------------------------