diff --git a/docs/source/conf.py b/docs/source/conf.py index 3c6517bdf2e..034aa2ab114 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,17 +10,22 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. - import glob import inspect import os import shutil import sys +import torch + +# this removes "Initializes internal Module state, shared by both nn.Module and ScriptModule." from the docs +torch.nn.Module.__init__.__doc__ = "" + import pt_lightning_sphinx_theme import torchmetrics + _PATH_HERE = os.path.abspath(os.path.dirname(__file__)) _PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", "..")) sys.path.insert(0, os.path.abspath(_PATH_ROOT)) diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index a220fb92191..14f071b72f9 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -306,6 +306,7 @@ information on this topic. .. autoclass:: torchmetrics.MetricCollection :noindex: + :exclude-members: update, compute, forward **************************** diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 787f22d8b77..27d3cb207b9 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -85,7 +85,10 @@ class name as key for the output dict. ValueError: If ``postfix`` is set and it is not a string. - Example (input as list): + Example:: + In the most basic case, the metrics can be passed in as a list or tuple. The keys of the output dict will be + the same as the class name of the metric: + >>> from torch import tensor >>> from pprint import pprint >>> from torchmetrics import MetricCollection @@ -101,7 +104,10 @@ class name as key for the output dict. 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)} - Example (input as arguments): + Example:: + Alternatively, metrics can be passed in as arguments. The keys of the output dict will be the same as the + class name of the metric: + >>> metrics = MetricCollection(MulticlassAccuracy(num_classes=3, average='micro'), ... MulticlassPrecision(num_classes=3, average='macro'), ... MulticlassRecall(num_classes=3, average='macro')) @@ -110,7 +116,10 @@ class name as key for the output dict. 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)} - Example (input as dict): + Example:: + If multiple of the same metric class (with different parameters) should be chained together, metrics can be + passed in as a dict and the output dict will have the same keys as the input dict: + >>> metrics = MetricCollection({'micro_recall': MulticlassRecall(num_classes=3, average='micro'), ... 'macro_recall': MulticlassRecall(num_classes=3, average='macro')}) >>> same_metric = metrics.clone() @@ -119,20 +128,10 @@ class name as key for the output dict. >>> pprint(same_metric(preds, target)) {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)} - Example (specification of compute groups): - >>> metrics = MetricCollection( - ... MulticlassRecall(num_classes=3, average='macro'), - ... MulticlassPrecision(num_classes=3, average='macro'), - ... MeanSquaredError(), - ... compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']] - ... ) - >>> metrics.update(preds, target) - >>> pprint(metrics.compute()) - {'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)} - >>> pprint(metrics.compute_groups) - {0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']} + Example:: + Metric collections can also be nested up to a single time. The output of the collection will still be a single + dict with the prefix and postfix arguments from the nested collection: - Example (nested metric collections): >>> metrics = MetricCollection([ ... MetricCollection([ ... MulticlassAccuracy(num_classes=3, average='macro'), @@ -148,6 +147,23 @@ class name as key for the output dict. 'valmetrics/MulticlassAccuracy_micro': tensor(0.1250), 'valmetrics/MulticlassPrecision_macro': tensor(0.0667), 'valmetrics/MulticlassPrecision_micro': tensor(0.1250)} + + Example:: + The `compute_groups` argument allow you to specify which metrics should share metric state. By default, this + will automatically be derived but can also be set manually. + + >>> metrics = MetricCollection( + ... MulticlassRecall(num_classes=3, average='macro'), + ... MulticlassPrecision(num_classes=3, average='macro'), + ... MeanSquaredError(), + ... compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']] + ... ) + >>> metrics.update(preds, target) + >>> pprint(metrics.compute()) + {'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)} + >>> pprint(metrics.compute_groups) + {0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']} + """ _modules: Dict[str, Metric] # type: ignore[assignment] diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 19d6f64c206..eea115b0215 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -17,6 +17,7 @@ import builtins import functools import inspect +import os from abc import ABC, abstractmethod from contextlib import contextmanager from copy import deepcopy @@ -50,33 +51,30 @@ def jit_distributed_available() -> bool: class Metric(Module, ABC): """Base class for all metrics present in the Metrics API. - Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to - handle distributed synchronization and per-step metric computation. + This class is inherited by all metrics and implements the following functionality: + 1. Handles the transfer of metric states to correct device + 2. Handles the synchronization of metric states across processes - Override ``update()`` and ``compute()`` functions to implement your own metric. Use - ``add_state()`` to register metric state variables which keep track of state on each - call of ``update()`` and are synchronized across processes when ``compute()`` is called. + The three core methods of the base class are + * ``add_state()`` + * ``forward()`` + * ``reset()`` - Note: - Metric state variables can either be :class:`~torch.Tensor` or an empty list which can we used - to store :class:`~torch.Tensor`. + which should almost never be overwritten by child classes. Instead, the following methods should be overwritten + * ``update()`` + * ``compute()`` - Note: - Different metrics only override ``update()`` and not ``forward()``. A call to ``update()`` - is valid, but it won't return the metric value at the current step. A call to ``forward()`` - automatically calls ``update()`` and also returns the metric value at the current step. Args: kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info. - - compute_on_cpu: If metric state should be stored on CPU during computations. Only works - for list states. + - compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states. - dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False`` - process_group: The process group on which the synchronization is called. Default is the world. - - dist_sync_fn: function that performs the allgather option on the metric state. Default is an - custom implementation that calls ``torch.distributed.all_gather`` internally. - - distributed_available_fn: function that checks if the distributed backend is available. - Defaults to a check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``. + - dist_sync_fn: Function that performs the allgather option on the metric state. Default is an custom + implementation that calls ``torch.distributed.all_gather`` internally. + - distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a + check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``. - sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True`` - compute_with_cache: If results from ``compute`` should be cached. Default is ``False`` """ @@ -191,6 +189,13 @@ def add_state( ) -> None: """Add metric state variable. Only used by subclasses. + Metric state variables are either `:class:`~torch.Tensor` or an empty list, which can be appended to by the + metric. Each state variable must have a unique name associated with it. State variables are accessible as + attributes of the metric i.e, if ``name`` is ``"my_state"`` then its value can be accessed from an instance + ``metric`` as ``metric.my_state``. Metric states behave like buffers and parameters of :class:`~torch.nn.Module` + as they are also updated when ``.to()`` is called. Unlike parameters and buffers, metric states are not by + default saved in the modules :attr:`~torch.nn.Module.state_dict`. + Args: name: The name of the state variable. The variable will then be accessible at ``self.name``. default: Default value of the state; can either be a :class:`~torch.Tensor` or an empty list. @@ -225,7 +230,8 @@ def add_state( ValueError: If ``default`` is not a ``tensor`` or an ``empty list``. ValueError: - If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``None``. + If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``"min"``, + ``"max"`` or ``None``. """ if not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default): raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") @@ -259,6 +265,17 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: Serves the dual purpose of both computing the metric on the current batch of inputs but also add the batch statistics to the overall accumululating metric state. Input arguments are the exact same as corresponding ``update`` method. The returned output is the exact same as the output of ``compute``. + + Args: + args: Any arguments as required by the metric ``update`` method. + kwargs: Any keyword arguments as required by the metric ``update`` method. + + Returns: + The output of the ``compute`` method evaluated on the current batch. + + Raises: + TorchMetricsUserError: + If the metric is already synced and ``forward`` is called again. """ # check if states are already synced if self._is_synced: @@ -467,6 +484,10 @@ def sync( should_sync: Whether to apply to state synchronization. This will have an impact only when running in a distributed setting. distributed_available: Function to determine if we are running inside a distributed setting + + Raises: + TorchMetricsUserError: + If the metric is already synced and ``sync`` is called again. """ if self._is_synced and should_sync: raise TorchMetricsUserError("The Metric has already been synced.") @@ -679,28 +700,28 @@ def device(self) -> "torch.device": def type(self, dst_type: Union[str, torch.dtype]) -> "Metric": # noqa: A003 """Override default and prevent dtype casting. - Please use `metric.set_dtype(dtype)` instead. + Please use :meth:`Metric.set_dtype` instead. """ return self def float(self) -> "Metric": # noqa: A003 """Override default and prevent dtype casting. - Please use `metric.set_dtype(dtype)` instead. + Please use :meth:`Metric.set_dtype` instead. """ return self def double(self) -> "Metric": """Override default and prevent dtype casting. - Please use `metric.set_dtype(dtype)` instead. + Please use :meth:`Metric.set_dtype` instead. """ return self def half(self) -> "Metric": """Override default and prevent dtype casting. - Please use `metric.set_dtype(dtype)` instead. + Please use :meth:`Metric.set_dtype` instead. """ return self @@ -716,7 +737,7 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric": return out def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module: - """Overwrite _apply function such that we can also move metric states to the correct device. + """Overwrite `_apply` function such that we can also move metric states to the correct device. This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods are called. Dtype conversion is garded and will only happen through the special `set_dtype` method. diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 43487cd1a2c..cfceadba9bb 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -147,8 +147,12 @@ def plot_single_or_multi_val( xlim = ax.get_xlim() factor = 0.1 * (xlim[1] - xlim[0]) - y_ = [lower_bound, upper_bound] if lower_bound and upper_bound else [] - ax.hlines(y_, xlim[0], xlim[1], linestyles="dashed", colors="k") + y_lines = [] + if lower_bound is not None: + y_lines.append(lower_bound) + if upper_bound is not None: + y_lines.append(upper_bound) + ax.hlines(y_lines, xlim[0], xlim[1], linestyles="dashed", colors="k") if higher_is_better is not None: if lower_bound is not None and not higher_is_better: ax.set_xlim(xlim[0] - factor, xlim[1])