From 962f82db7b0ab7ca32fe050aba45f1d3f9fe64c9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 25 May 2023 14:25:15 +0200 Subject: [PATCH] Docs: Small plotting note for custom implementations (#1807) * docs * docs * docs * Apply suggestions from code review Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --------- Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --- docs/source/pages/implement.rst | 50 ++++++++++++++++++++++++--------- docs/source/pages/plotting.rst | 2 ++ 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index 7f04dc94474..89f958117ac 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -2,13 +2,15 @@ .. testsetup:: * - from typing import Optional + from typing import Optional, Sequence, Union + from torch import Tensor ********************* Implementing a Metric ********************* -To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following methods: +To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following +methods: - ``__init__()``: Each state variable should be called using ``self.add_state(...)``. - ``update()``: Any code needed to update the state given any inputs to the metric. @@ -32,7 +34,7 @@ Example implementation: self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - def update(self, preds: torch.Tensor, target: torch.Tensor): + def update(self, preds: Tensor, target: Tensor): preds, target = self._input_format(preds, target) assert preds.shape == target.shape @@ -48,21 +50,41 @@ Additionally you may want to set the class properties: `is_differentiable`, `hig .. 
testcode:: - from torchmetrics import Metric + from torchmetrics import Metric + + class MyMetric(Metric): + # Set to True if the metric is differentiable else set to False + is_differentiable: Optional[bool] = None + + # Set to True if the metric reaches its optimal value when the metric is maximized. + # Set to False if it does when the metric is minimized. + higher_is_better: Optional[bool] = True + + # Set to True if the metric during 'update' requires access to the global metric + # state for its calculations. If not, setting this to False indicates that all + # batch states are independent and we will optimize the runtime of 'forward' + full_state_update: bool = True - class MyMetric(Metric): - # Set to True if the metric is differentiable else set to False - is_differentiable: Optional[bool] = None +Finally, from torchmetrics v1.0.0 onwards, we also support plotting of metrics through the `.plot` method. By default +this method will raise `NotImplementedError` but can be implemented by the user to provide a custom plot for the metric. +For any metric that returns a simple scalar tensor, or a dict of scalar tensors, the internal `._plot` method can be +used, which provides the common plotting functionality for most metrics in torchmetrics. + +.. testcode:: + + from torchmetrics import Metric + from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE - # Set to True if the metric reaches it optimal value when the metric is maximized. - # Set to False if it when the metric is minimized. - higher_is_better: Optional[bool] = True + class MyMetric(Metric): + ... - # Set to True if the metric during 'update' requires access to the global metric - # state for its calculations. 
If not, setting this to False indicates that all - # batch states are independent and we will optimize the runtime of 'forward' - full_state_update: bool = True + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + return self._plot(val, ax) +If the metric returns a more complex output, a custom implementation of the `plot` method is required. For more details +on the plotting API, see this :ref:`page`. Internal implementation details ------------------------------- diff --git a/docs/source/pages/plotting.rst b/docs/source/pages/plotting.rst index bc8a6081b7a..a35410923f5 100644 --- a/docs/source/pages/plotting.rst +++ b/docs/source/pages/plotting.rst @@ -1,3 +1,5 @@ +.. _plotting: + .. testsetup:: * import torch