Fix a few doc formatting errors #1870

Merged 6 commits on Jul 3, 2023
7 changes: 6 additions & 1 deletion docs/source/conf.py
@@ -10,17 +10,22 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.

import glob
import inspect
import os
import shutil
import sys

import torch

# this removes "Initializes internal Module state, shared by both nn.Module and ScriptModule." from the docs
torch.nn.Module.__init__.__doc__ = ""

import pt_lightning_sphinx_theme

import torchmetrics


_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
sys.path.insert(0, os.path.abspath(_PATH_ROOT))
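
For context, the blanked docstring above is the text Sphinx autodoc would otherwise inherit into every metric's ``__init__`` documentation. A small, purely illustrative snippet to observe the effect of that one line:

    import torch

    # Before blanking, this prints the inherited nn.Module.__init__ docstring
    # (the text quoted in the comment in conf.py above).
    print(torch.nn.Module.__init__.__doc__)

    # The same assignment as in conf.py; afterwards autodoc (and help()) see an empty docstring.
    torch.nn.Module.__init__.__doc__ = ""
    print(torch.nn.Module.__init__.__doc__)  # prints an empty line
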
1 change: 1 addition & 0 deletions docs/source/pages/overview.rst
@@ -306,6 +306,7 @@ information on this topic.

.. autoclass:: torchmetrics.MetricCollection
:noindex:
:exclude-members: update, compute, forward


****************************
48 changes: 32 additions & 16 deletions src/torchmetrics/collections.py
@@ -85,7 +85,10 @@ class name as key for the output dict.
ValueError:
If ``postfix`` is set and it is not a string.

Example (input as list):
Example::
In the most basic case, the metrics can be passed in as a list or tuple. The keys of the output dict will be
the same as the class name of the metric:

>>> from torch import tensor
>>> from pprint import pprint
>>> from torchmetrics import MetricCollection
@@ -101,7 +104,10 @@ class name as key for the output dict.
'MulticlassPrecision': tensor(0.0667),
'MulticlassRecall': tensor(0.1111)}

Example (input as arguments):
Example::
Alternatively, metrics can be passed in as arguments. The keys of the output dict will be the same as the
class name of the metric:

>>> metrics = MetricCollection(MulticlassAccuracy(num_classes=3, average='micro'),
... MulticlassPrecision(num_classes=3, average='macro'),
... MulticlassRecall(num_classes=3, average='macro'))
@@ -110,7 +116,10 @@ class name as key for the output dict.
'MulticlassPrecision': tensor(0.0667),
'MulticlassRecall': tensor(0.1111)}

Example (input as dict):
Example::
If multiple of the same metric class (with different parameters) should be chained together, metrics can be
passed in as a dict and the output dict will have the same keys as the input dict:

>>> metrics = MetricCollection({'micro_recall': MulticlassRecall(num_classes=3, average='micro'),
... 'macro_recall': MulticlassRecall(num_classes=3, average='macro')})
>>> same_metric = metrics.clone()
@@ -119,20 +128,10 @@ class name as key for the output dict.
>>> pprint(same_metric(preds, target))
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}

Example (specification of compute groups):
>>> metrics = MetricCollection(
... MulticlassRecall(num_classes=3, average='macro'),
... MulticlassPrecision(num_classes=3, average='macro'),
... MeanSquaredError(),
... compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']]
... )
>>> metrics.update(preds, target)
>>> pprint(metrics.compute())
{'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)}
>>> pprint(metrics.compute_groups)
{0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']}
Example::
Metric collections can also be nested one level deep. The output of the collection will still be a single
dict with the prefix and postfix arguments from the nested collection:

Example (nested metric collections):
>>> metrics = MetricCollection([
... MetricCollection([
... MulticlassAccuracy(num_classes=3, average='macro'),
@@ -148,6 +147,23 @@ class name as key for the output dict.
'valmetrics/MulticlassAccuracy_micro': tensor(0.1250),
'valmetrics/MulticlassPrecision_macro': tensor(0.0667),
'valmetrics/MulticlassPrecision_micro': tensor(0.1250)}

Example::
The `compute_groups` argument allows you to specify which metrics should share metric state. By default, this
will automatically be derived but can also be set manually.

>>> metrics = MetricCollection(
... MulticlassRecall(num_classes=3, average='macro'),
... MulticlassPrecision(num_classes=3, average='macro'),
... MeanSquaredError(),
... compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']]
... )
>>> metrics.update(preds, target)
>>> pprint(metrics.compute())
{'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)}
>>> pprint(metrics.compute_groups)
{0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']}

"""

_modules: Dict[str, Metric] # type: ignore[assignment]
69 changes: 45 additions & 24 deletions src/torchmetrics/metric.py
@@ -17,6 +17,7 @@
import builtins
import functools
import inspect
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from copy import deepcopy
@@ -50,33 +51,30 @@ def jit_distributed_available() -> bool:
class Metric(Module, ABC):
"""Base class for all metrics present in the Metrics API.

Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to
handle distributed synchronization and per-step metric computation.
This class is inherited by all metrics and implements the following functionality:
1. Handles the transfer of metric states to the correct device
2. Handles the synchronization of metric states across processes

Override ``update()`` and ``compute()`` functions to implement your own metric. Use
``add_state()`` to register metric state variables which keep track of state on each
call of ``update()`` and are synchronized across processes when ``compute()`` is called.
The three core methods of the base class are
* ``add_state()``
* ``forward()``
* ``reset()``

Note:
Metric state variables can either be :class:`~torch.Tensor` or an empty list which can be used
to store :class:`~torch.Tensor`.
which should almost never be overwritten by child classes. Instead, the following methods should be overwritten
* ``update()``
* ``compute()``

Note:
Different metrics only override ``update()`` and not ``forward()``. A call to ``update()``
is valid, but it won't return the metric value at the current step. A call to ``forward()``
automatically calls ``update()`` and also returns the metric value at the current step.

Args:
kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

- compute_on_cpu: If metric state should be stored on CPU during computations. Only works
for list states.
- compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False``
- process_group: The process group on which the synchronization is called. Default is the world.
- dist_sync_fn: function that performs the allgather option on the metric state. Default is an
custom implementation that calls ``torch.distributed.all_gather`` internally.
- distributed_available_fn: function that checks if the distributed backend is available.
Defaults to a check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
- dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom
implementation that calls ``torch.distributed.all_gather`` internally.
- distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a
check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
- sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``
- compute_with_cache: If results from ``compute`` should be cached. Default is ``False``
"""
@@ -191,6 +189,13 @@ def add_state(
) -> None:
"""Add metric state variable. Only used by subclasses.

Metric state variables are either :class:`~torch.Tensor` or an empty list, which can be appended to by the
metric. Each state variable must have a unique name associated with it. State variables are accessible as
attributes of the metric i.e., if ``name`` is ``"my_state"`` then its value can be accessed from an instance
``metric`` as ``metric.my_state``. Metric states behave like buffers and parameters of :class:`~torch.nn.Module`
as they are also updated when ``.to()`` is called. Unlike parameters and buffers, metric states are not by
default saved in the module's :attr:`~torch.nn.Module.state_dict`.

Args:
name: The name of the state variable. The variable will then be accessible at ``self.name``.
default: Default value of the state; can either be a :class:`~torch.Tensor` or an empty list.
@@ -225,7 +230,8 @@ def add_state(
ValueError:
If ``default`` is not a ``tensor`` or an ``empty list``.
ValueError:
If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``None``.
If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``"min"``,
``"max"`` or ``None``.
"""
if not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default):
raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")
@@ -259,6 +265,17 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
Serves the dual purpose of computing the metric on the current batch of inputs while also adding the batch
statistics to the overall accumulating metric state. Input arguments are the same as those of the corresponding
``update`` method. The returned output is the same as the output of ``compute``.

Args:
args: Any arguments as required by the metric ``update`` method.
kwargs: Any keyword arguments as required by the metric ``update`` method.

Returns:
The output of the ``compute`` method evaluated on the current batch.

Raises:
TorchMetricsUserError:
If the metric is already synced and ``forward`` is called again.
"""
# check if states are already synced
if self._is_synced:
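
A short usage sketch of the behaviour documented above, per-batch value from ``forward`` versus accumulated value from ``compute``, using the existing ``MulticlassAccuracy`` metric (the batches are made up for illustration):

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    metric = MulticlassAccuracy(num_classes=3)
    batches = [
        (torch.tensor([0, 1, 2]), torch.tensor([0, 2, 2])),
        (torch.tensor([1, 1, 0]), torch.tensor([1, 0, 0])),
    ]
    for preds, target in batches:
        batch_acc = metric(preds, target)  # forward(): metric value for this batch only, state also updated
    epoch_acc = metric.compute()           # metric value accumulated over both batches
    metric.reset()                         # clear the state before the next epoch
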
Expand Down Expand Up @@ -467,6 +484,10 @@ def sync(
should_sync: Whether to apply state synchronization. This will have an impact
only when running in a distributed setting.
distributed_available: Function to determine if we are running inside a distributed setting

Raises:
TorchMetricsUserError:
If the metric is already synced and ``sync`` is called again.
"""
if self._is_synced and should_sync:
raise TorchMetricsUserError("The Metric has already been synced.")
Expand Down Expand Up @@ -679,28 +700,28 @@ def device(self) -> "torch.device":
def type(self, dst_type: Union[str, torch.dtype]) -> "Metric": # noqa: A003
"""Override default and prevent dtype casting.

Please use `metric.set_dtype(dtype)` instead.
Please use :meth:`Metric.set_dtype` instead.
"""
return self

def float(self) -> "Metric": # noqa: A003
"""Override default and prevent dtype casting.

Please use `metric.set_dtype(dtype)` instead.
Please use :meth:`Metric.set_dtype` instead.
"""
return self

def double(self) -> "Metric":
"""Override default and prevent dtype casting.

Please use `metric.set_dtype(dtype)` instead.
Please use :meth:`Metric.set_dtype` instead.
"""
return self

def half(self) -> "Metric":
"""Override default and prevent dtype casting.

Please use `metric.set_dtype(dtype)` instead.
Please use :meth:`Metric.set_dtype` instead.
"""
return self

@@ -716,7 +737,7 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":
return out

def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> Module:
"""Overwrite _apply function such that we can also move metric states to the correct device.
"""Overwrite `_apply` function such that we can also move metric states to the correct device.

This method is called by the base ``nn.Module`` class whenever `.to`, `.cuda`, `.float`, `.half` etc. methods
are called. Dtype conversion is guarded and will only happen through the special `set_dtype` method.
8 changes: 6 additions & 2 deletions src/torchmetrics/utilities/plot.py
@@ -147,8 +147,12 @@ def plot_single_or_multi_val(
xlim = ax.get_xlim()
factor = 0.1 * (xlim[1] - xlim[0])

y_ = [lower_bound, upper_bound] if lower_bound and upper_bound else []
ax.hlines(y_, xlim[0], xlim[1], linestyles="dashed", colors="k")
y_lines = []
if lower_bound is not None:
y_lines.append(lower_bound)
if upper_bound is not None:
y_lines.append(upper_bound)
ax.hlines(y_lines, xlim[0], xlim[1], linestyles="dashed", colors="k")
if higher_is_better is not None:
if lower_bound is not None and not higher_is_better:
ax.set_xlim(xlim[0] - factor, xlim[1])
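
The rewritten check matters because a bound of ``0.0`` is falsy in Python, so the old one-liner silently dropped valid bounds; a small sketch of the difference (values chosen only for illustration):

    lower_bound, upper_bound = 0.0, 1.0

    # old behaviour: bool(0.0) is False, so both bounds are dropped
    old_lines = [lower_bound, upper_bound] if lower_bound and upper_bound else []
    assert old_lines == []

    # new behaviour: each bound is kept whenever it is actually provided
    new_lines = []
    if lower_bound is not None:
        new_lines.append(lower_bound)
    if upper_bound is not None:
        new_lines.append(upper_bound)
    assert new_lines == [0.0, 1.0]
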