Commit

Merge branch 'master' into bugfix/image_metrics
SkafteNicki authored May 25, 2023
2 parents ede6f59 + 962f82d commit c35d0f0
Showing 2 changed files with 38 additions and 14 deletions.
50 changes: 36 additions & 14 deletions docs/source/pages/implement.rst
@@ -2,13 +2,15 @@

.. testsetup:: *

    from typing import Optional, Sequence, Union
    from torch import Tensor

*********************
Implementing a Metric
*********************

To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following methods:

- ``__init__()``: Each state variable should be registered by calling ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
@@ -32,7 +34,7 @@ Example implementation:
            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: Tensor, target: Tensor):
            preds, target = self._input_format(preds, target)
            assert preds.shape == target.shape

@@ -48,21 +50,41 @@ Additionally you may want to set the class properties: `is_differentiable`, `higher_is_better` and `full_state_update`.

.. testcode::

    from torchmetrics import Metric

    class MyMetric(Metric):
        # Set to True if the metric is differentiable else set to False
        is_differentiable: Optional[bool] = None

        # Set to True if the metric reaches its optimal value when maximized.
        # Set to False if it reaches its optimal value when minimized.
        higher_is_better: Optional[bool] = True

        # Set to True if the metric during 'update' requires access to the global metric
        # state for its calculations. If not, setting this to False indicates that all
        # batch states are independent and we will optimize the runtime of 'forward'
        full_state_update: bool = True

Finally, from torchmetrics v1.0.0 onwards, we also support plotting of metrics through the `.plot` method. By default
this method raises `NotImplementedError`, but it can be overridden by the user to provide a custom plot for the metric.
For any metric that returns a simple scalar tensor, or a dict of scalar tensors, the internal `._plot` method can be
used; it provides the common plotting functionality for most metrics in torchmetrics.

.. testcode::

    from torchmetrics import Metric
    from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

    class MyMetric(Metric):
        ...

        def plot(
            self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
        ) -> _PLOT_OUT_TYPE:
            return self._plot(val, ax)

If the metric returns a more complex output, a custom implementation of the `plot` method is required. For more details
on the plotting API, see this :ref:`page <plotting>`.

Internal implementation details
-------------------------------
2 changes: 2 additions & 0 deletions docs/source/pages/plotting.rst
@@ -1,3 +1,5 @@
.. _plotting:

.. testsetup:: *

    import torch
