From 06e0e8ceab50bef25f34756179a7ec8f59009d3b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 17 Apr 2023 18:26:06 +0200 Subject: [PATCH] Add plotting n/n (#1682) Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 +- docs/source/conf.py | 1 + docs/source/index.rst | 1 + docs/source/pages/plotting.rst | 254 ++++++++++++++++++ docs/source/pyplots/binary_accuracy.py | 15 ++ .../pyplots/binary_accuracy_multistep.py | 24 ++ docs/source/pyplots/collection_binary.py | 29 ++ .../pyplots/collection_binary_together.py | 29 ++ docs/source/pyplots/confusion_matrix.py | 16 ++ docs/source/pyplots/multiclass_accuracy.py | 16 ++ docs/source/pyplots/tracker_binary.py | 44 +++ .../classification/confusion_matrix.py | 35 ++- src/torchmetrics/collections.py | 92 +++++++ src/torchmetrics/utilities/plot.py | 17 +- src/torchmetrics/wrappers/tracker.py | 2 +- tests/unittests/utilities/test_plot.py | 31 +++ 16 files changed, 593 insertions(+), 16 deletions(-) create mode 100644 docs/source/pages/plotting.rst create mode 100644 docs/source/pyplots/binary_accuracy.py create mode 100644 docs/source/pyplots/binary_accuracy_multistep.py create mode 100644 docs/source/pyplots/collection_binary.py create mode 100644 docs/source/pyplots/collection_binary_together.py create mode 100644 docs/source/pyplots/confusion_matrix.py create mode 100644 docs/source/pyplots/multiclass_accuracy.py create mode 100644 docs/source/pyplots/tracker_binary.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 85d854388f7..e73fa061ba9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,7 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1631](https://github.com/Lightning-AI/metrics/pull/1631), [#1650](https://github.com/Lightning-AI/metrics/pull/1650), [#1639](https://github.com/Lightning-AI/metrics/pull/1639), - [#1660](https://github.com/Lightning-AI/metrics/pull/1660) + [#1660](https://github.com/Lightning-AI/metrics/pull/1660), + [#1682](https://github.com/Lightning-AI/torchmetrics/pull/1682), ) diff --git a/docs/source/conf.py b/docs/source/conf.py index dae2a5ba635..1fa459d3692 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -106,6 +106,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # Set that source code from plotting is always included plot_include_source = True +plot_html_show_source_link = True # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] diff --git a/docs/source/index.rst b/docs/source/index.rst index 4566ec93401..9da8cf0a51a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -130,6 +130,7 @@ Or directly from conda pages/quickstart all-metrics pages/overview + pages/plotting pages/implement pages/lightning diff --git a/docs/source/pages/plotting.rst b/docs/source/pages/plotting.rst new file mode 100644 index 00000000000..bc8a6081b7a --- /dev/null +++ b/docs/source/pages/plotting.rst @@ -0,0 +1,254 @@ +.. testsetup:: * + + import torch + import matplotlib + import matplotlib.pyplot as plt + import torchmetrics + +######## +Plotting +######## + +.. note:: + The visualization/plotting interface of Torchmetrics requires ``matplotlib`` to be installed. Install with either + ``pip install matplotlib`` or ``pip install 'torchmetrics[visual]'``. 
If the latter option is chosen, the
    `SciencePlots package `_ is also installed and all plots in
    Torchmetrics will default to using that style.

Torchmetrics comes with built-in support for quick visualization of your metrics by simply using the ``.plot`` method
that all modular metrics implement. This method provides a consistent interface for basic plotting of all metrics.

.. code-block:: python

    metric = AnyMetricYouLike()
    for i in range(num_updates):
        metric.update(preds[i], target[i])
    fig, ax = metric.plot()

``.plot`` will always return two objects: ``fig`` is an instance of :class:`~matplotlib.figure.Figure` which contains
figure-level attributes, and ``ax`` is an instance of :class:`~matplotlib.axes.Axes` that contains all the elements of
the plot. These two objects allow you to change attributes of the plot after it is created. For example, if you want to
make the font size of the x-axis labels a bit bigger, give the figure a nice title, and finally save it, building on the
example above, it could be done like this:

.. code-block:: python

    ax.tick_params(axis="x", labelsize=20)
    fig.suptitle("This is a nice plot")
    fig.savefig("my_awesome_plot.png")

If you want to include a Torchmetrics plot in a bigger figure that has subfigures and subaxes, all ``.plot`` methods
support an optional ``ax`` argument where you can pass in the subaxes you want the plot to be inserted into:

.. code-block:: python

    # combine plotting of two metrics into one figure
    fig, ax = plt.subplots(nrows=1, ncols=2)
    metric1 = Metric1()
    metric2 = Metric2()
    for i in range(num_updates):
        metric1.update(preds[i], target[i])
        metric2.update(preds[i], target[i])
    metric1.plot(ax=ax[0])
    metric2.plot(ax=ax[1])

**********************
Plotting a single step
**********************

At the most basic level, the ``.plot`` method can be used to plot the value from a single step. This can be done in two
ways:

* The ``.plot`` method is called with no input; internally ``metric.compute()`` is called and that value is plotted.
* ``.plot`` is called on a single value returned by the metric, for example from ``metric.forward()``.

In both cases it will generate a plot like this (``Accuracy`` as an example):

.. code-block:: python

    metric = torchmetrics.Accuracy(task="binary")
    for _ in range(num_updates):
        metric.update(torch.rand(10), torch.randint(2, (10,)))
    fig, ax = metric.plot()

.. plot:: pyplots/binary_accuracy.py
    :scale: 100
    :include-source: false
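The example above uses the first mode, where ``metric.compute`` provides the value. The second mode, plotting a value
returned directly by the metric, is sketched below; it relies on the fact that ``metric.forward`` both updates the
internal state and returns the value for the current batch:

.. code-block:: python

    metric = torchmetrics.Accuracy(task="binary")
    # forward returns the metric value computed on this particular batch
    batch_value = metric(torch.rand(10), torch.randint(2, (10,)))
    fig, ax = metric.plot(batch_value)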
A single point plot is not that informative in itself, but if available we will try to include additional information,
such as the lower and upper bounds that the particular metric can take and whether the metric should be minimized or
maximized to be optimal. This is true for all metrics that return a scalar tensor.
Some metrics return multiple values (such as a tensor with multiple elements or a dict of scalar tensors), and in
that case calling ``.plot`` will return a figure similar to this:

.. code-block:: python

    metric = torchmetrics.Accuracy(task="multiclass", num_classes=3, average=None)
    for _ in range(num_updates):
        metric.update(torch.randint(3, (10,)), torch.randint(3, (10,)))
    fig, ax = metric.plot()

.. plot:: pyplots/multiclass_accuracy.py
    :scale: 100
    :include-source: false

Here, each element is assumed to be an independent metric and is plotted as its own point for comparison. Metrics with
more structured output do not fit this point-based format; for those, the ``.plot`` method instead returns a specialized
plot for that particular metric. Take for example the ``ConfusionMatrix`` metric:

.. code-block:: python

    metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)
    for _ in range(num_updates):
        metric.update(torch.randint(3, (10,)), torch.randint(3, (10,)))
    fig, ax = metric.plot()

.. plot:: pyplots/confusion_matrix.py
    :scale: 100
    :include-source: false

If you prefer to use the functional interface of Torchmetrics, you can also plot the values returned by the functional
counterpart. However, you still need to initialize the corresponding metric class to get the information about the
metric:

.. code-block:: python

    plot_class = torchmetrics.Accuracy(task="multiclass", num_classes=3)
    value = torchmetrics.functional.accuracy(
        torch.randint(3, (10,)), torch.randint(3, (10,)), task="multiclass", num_classes=3
    )
    fig, ax = plot_class.plot(value)

***********************
Plotting multiple steps
***********************

In the above examples we have only plotted a single step/single value, but it is also possible to plot multiple steps
from the same metric. This is often the case when training a machine learning model, where you track one or more
metrics that you want to plot as they change over time. This can be done by providing a sequence of outputs from
any metric, computed using ``metric.forward`` or ``metric.compute``. For example, if we want to plot the accuracy of
a model over time, we could do it like this:

.. code-block:: python

    metric = torchmetrics.Accuracy(task="binary")
    values = []
    for step in range(num_steps):
        for _ in range(num_updates):
            metric.update(preds(step), target(step))
        values.append(metric.compute())  # save value
        metric.reset()
    fig, ax = metric.plot(values)

.. plot:: pyplots/binary_accuracy_multistep.py
    :scale: 100
    :include-source: false

Note that metrics with specialized visualizations that do not return simple scalar tensors, such as ``ConfusionMatrix``
and ``ROC``, do not support plotting multiple steps out of the box; for those, the user needs to manually plot the
values for each step, as sketched below.
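A minimal sketch of such manual per-step plotting, assuming the same ``preds``/``target`` generators and
``num_steps``/``num_updates`` constants as in the example above, and using the ``ax`` argument to place each step in
its own subplot:

.. code-block:: python

    confmat = torchmetrics.ConfusionMatrix(task="binary")
    step_values = []
    for step in range(num_steps):
        for _ in range(num_updates):
            confmat.update(preds(step), target(step))
        step_values.append(confmat.compute())  # one confusion matrix per step
        confmat.reset()

    # plot each step into its own subplot
    fig, axs = plt.subplots(nrows=1, ncols=len(step_values))
    for step, (value, ax) in enumerate(zip(step_values, axs)):
        confmat.plot(val=value, ax=ax)
        ax.set_title(f"step {step}")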
********************************
Plotting a collection of metrics
********************************

``MetricCollection`` also supports the ``.plot`` method; by default it works by returning a collection of plots for
all its members. Thus, instead of returning a single ``(fig, ax)`` pair, calling the ``.plot`` method of a
``MetricCollection`` will return a sequence of such pairs, one for each member in the collection. In the following
example we form a collection of binary classification metrics and redirect the output of ``.plot`` to different
subplots:

.. code-block:: python

    collection = torchmetrics.MetricCollection(
        torchmetrics.Accuracy(task="binary"),
        torchmetrics.Recall(task="binary"),
        torchmetrics.Precision(task="binary"),
    )
    fig, ax = plt.subplots(nrows=1, ncols=3)
    values = []
    for step in range(num_steps):
        for _ in range(num_updates):
            collection.update(preds(step), target(step))
        values.append(collection.compute())
        collection.reset()
    collection.plot(val=values, ax=ax)

.. plot:: pyplots/collection_binary.py
    :scale: 100
    :include-source: false

However, the ``plot`` method of ``MetricCollection`` also supports an additional argument called ``together`` that will
automatically try to plot all the metrics in the collection together in the same plot (with appropriate labels). This
is only possible if all the metrics in the collection return a scalar tensor.

.. code-block:: python

    collection = torchmetrics.MetricCollection(
        torchmetrics.Accuracy(task="binary"),
        torchmetrics.Recall(task="binary"),
        torchmetrics.Precision(task="binary"),
    )
    values = []
    fig, ax = plt.subplots(figsize=(6.8, 4.8))
    for step in range(num_steps):
        for _ in range(num_updates):
            collection.update(preds(step), target(step))
        values.append(collection.compute())
        collection.reset()
    collection.plot(val=values, ax=ax, together=True)

.. plot:: pyplots/collection_binary_together.py
    :scale: 100
    :include-source: false

****************
Advanced example
****************

In the following we show how to use the ``.plot`` method to create a more advanced plot. We combine the functionality
of several metrics using ``MetricCollection`` and plot them together. In addition we rely on ``MetricTracker`` to keep
track of the metrics over multiple steps.

.. code-block:: python

    # Define a collection that mixes metrics returning scalar tensors with metrics that do not
    confmat = torchmetrics.ConfusionMatrix(task="binary")
    roc = torchmetrics.ROC(task="binary")
    collection = torchmetrics.MetricCollection(
        torchmetrics.Accuracy(task="binary"),
        torchmetrics.Recall(task="binary"),
        torchmetrics.Precision(task="binary"),
        confmat,
        roc,
    )

    # Define a tracker over the collection to easily keep track of the metrics over multiple steps
    tracker = torchmetrics.wrappers.MetricTracker(collection)

    # Run "training" loop
    for step in range(num_steps):
        tracker.increment()
        for _ in range(num_updates):
            tracker.update(preds(step), target(step))

    # Extract all metrics from all steps
    all_results = tracker.compute_all()

    # Construct a single figure with appropriate layout for all metrics
    fig = plt.figure(layout="constrained")
    ax1 = plt.subplot(2, 2, 1)
    ax2 = plt.subplot(2, 2, 2)
    ax3 = plt.subplot(2, 2, (3, 4))

    # For ConfusionMatrix and ROC we just plot the last step; note that we call the plot method of those metrics
    confmat.plot(val=all_results[-1]["BinaryConfusionMatrix"], ax=ax1)
    roc.plot(all_results[-1]["BinaryROC"], ax=ax2)

    # For the remaining metrics we plot the full history, but we need to extract the scalar values from the results
    scalar_results = [
        {k: v for k, v in ar.items() if isinstance(v, torch.Tensor) and v.numel() == 1} for ar in all_results
    ]
    tracker.plot(val=scalar_results, ax=ax3)

.. plot:: pyplots/tracker_binary.py
    :scale: 100
    :include-source: false
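Finally, since every call to ``.plot`` without an ``ax`` argument creates a new figure, remember to close figures you
are done with when plotting repeatedly inside a training loop; otherwise ``matplotlib`` keeps them all in memory. A
minimal sketch using standard ``matplotlib`` calls:

.. code-block:: python

    fig, ax = metric.plot(values)
    fig.savefig("accuracy.png")
    plt.close(fig)  # release the figure once it has been saved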
diff --git a/docs/source/pyplots/binary_accuracy.py b/docs/source/pyplots/binary_accuracy.py
new file mode 100644
index 00000000000..90539961e4f
--- /dev/null
+++ b/docs/source/pyplots/binary_accuracy.py
@@ -0,0 +1,15 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
num_updates = 10
num_steps = 5

fig, ax = plt.subplots(1, 1, figsize=(6.8, 4.8), dpi=500)
metric = torchmetrics.Accuracy(task="binary")
for _ in range(N):
    metric.update(torch.rand(10), torch.randint(2, (10,)))
metric.plot(ax=ax)
fig.show()
diff --git a/docs/source/pyplots/binary_accuracy_multistep.py b/docs/source/pyplots/binary_accuracy_multistep.py
new file mode 100644
index 00000000000..c2f29b6e2a4
--- /dev/null
+++ b/docs/source/pyplots/binary_accuracy_multistep.py
@@ -0,0 +1,24 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
num_updates = 10
num_steps = 5

w = torch.tensor([0.2, 0.8])
target = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
preds = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)

fig, ax = plt.subplots(1, 1, figsize=(6.8, 4.8), dpi=500)

metric = torchmetrics.Accuracy(task="binary")
values = []
for step in range(num_steps):
    for _ in range(N):
        metric.update(preds(step), target(step))
    values.append(metric.compute())  # save value
    metric.reset()
metric.plot(values, ax=ax)
fig.show()
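The ``preds`` and ``target`` lambdas in the script above (and in the collection scripts that follow) simulate a
classifier that improves over steps: as ``it`` grows, ``(it * w).softmax(dim=-1)`` concentrates probability mass on
class 1, so two independent draws agree more often. A standalone sketch of that behavior:

.. code-block:: python

    import torch

    w = torch.tensor([0.2, 0.8])
    for it in range(5):
        p = (it * w).softmax(dim=-1)
        print(it, p)  # probability mass shifts toward class 1 as `it` increases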
diff --git a/docs/source/pyplots/collection_binary.py b/docs/source/pyplots/collection_binary.py
new file mode 100644
index 00000000000..9f4b34de74b
--- /dev/null
+++ b/docs/source/pyplots/collection_binary.py
@@ -0,0 +1,29 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
num_updates = 10
num_steps = 5

w = torch.tensor([0.2, 0.8])
target = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
preds = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)

collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(6.8, 4.8), dpi=500)
values = []
for step in range(num_steps):
    for _ in range(N):
        collection.update(preds(step), target(step))
    values.append(collection.compute())
    collection.reset()
collection.plot(val=values, ax=ax)
fig.tight_layout()
fig.show()
diff --git a/docs/source/pyplots/collection_binary_together.py b/docs/source/pyplots/collection_binary_together.py
new file mode 100644
index 00000000000..f58f104f4e0
--- /dev/null
+++ b/docs/source/pyplots/collection_binary_together.py
@@ -0,0 +1,29 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
num_updates = 10
num_steps = 5

w = torch.tensor([0.2, 0.8])
target = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
preds = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)

collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
)

values = []
fig, ax = plt.subplots(1, 1, figsize=(6.8, 4.8), dpi=500)
for step in range(num_steps):
    for _ in range(N):
        collection.update(preds(step), target(step))
    values.append(collection.compute())
    collection.reset()
collection.plot(val=values, ax=ax, together=True)
fig.tight_layout()
fig.show()
diff --git a/docs/source/pyplots/confusion_matrix.py b/docs/source/pyplots/confusion_matrix.py
new file mode 100644
index 00000000000..1ab881bded2
--- /dev/null
+++ b/docs/source/pyplots/confusion_matrix.py
@@ -0,0 +1,16 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
num_updates = 10
num_steps = 5

fig, ax = plt.subplots(1, 1, figsize=(6.8, 4.8), dpi=500)

metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)
for _ in range(N):
    metric.update(torch.randint(3, (10,)), torch.randint(3, (10,)))
metric.plot(ax=ax)
fig.show()
diff --git a/docs/source/pyplots/multiclass_accuracy.py b/docs/source/pyplots/multiclass_accuracy.py
new file mode 100644
index 00000000000..ba08c16859c
--- /dev/null
+++ b/docs/source/pyplots/multiclass_accuracy.py
@@ -0,0 +1,16 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
num_updates = 10
num_steps = 5

fig, ax = plt.subplots(1, 1, figsize=(6.8, 4.8), dpi=500)

metric = torchmetrics.Accuracy(task="multiclass", num_classes=3, average=None)
for _ in range(N):
    metric.update(torch.randint(3, (10,)), torch.randint(3, (10,)))
metric.plot(ax=ax)
fig.show()
diff --git a/docs/source/pyplots/tracker_binary.py b/docs/source/pyplots/tracker_binary.py
new file mode 100644
index 00000000000..64f93453417
--- /dev/null
+++ b/docs/source/pyplots/tracker_binary.py
@@ -0,0 +1,44 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
num_updates = 10
num_steps = 5

w = torch.tensor([0.2, 0.8])
target = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
preds = lambda it: (it * torch.randn(100)).sigmoid()

confmat = torchmetrics.ConfusionMatrix(task="binary")
roc = torchmetrics.ROC(task="binary")
tracker = torchmetrics.wrappers.MetricTracker(
    torchmetrics.MetricCollection(
        torchmetrics.Accuracy(task="binary"),
        torchmetrics.Recall(task="binary"),
        torchmetrics.Precision(task="binary"),
        confmat,
        roc,
    )
)

fig = plt.figure(layout="constrained", figsize=(6.8, 4.8), dpi=500)
ax1 = plt.subplot(2, 2, 1)
ax2 = plt.subplot(2, 2, 2)
ax3 = plt.subplot(2, 2, (3, 4))

for step in range(num_steps):
    tracker.increment()
    for _ in range(N):
        tracker.update(preds(step), target(step))

# get the results from all steps and extract for confusion matrix and roc
all_results = tracker.compute_all()
confmat.plot(val=all_results[-1]["BinaryConfusionMatrix"], ax=ax1)
roc.plot(all_results[-1]["BinaryROC"], ax=ax2)

scalar_results = [
    {k: v for k, v in ar.items() if isinstance(v, torch.Tensor) and v.numel() == 1} for ar in all_results
]

tracker.plot(val=scalar_results, ax=ax3)
fig.show()
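The next diff threads an optional ``ax`` argument through all three confusion-matrix ``plot`` methods. A minimal usage
sketch of the resulting API, drawing the same matrix into two panels of a hypothetical figure:

.. code-block:: python

    import matplotlib.pyplot as plt
    import torch

    import torchmetrics

    metric = torchmetrics.ConfusionMatrix(task="binary")
    metric.update(torch.randint(2, (20,)), torch.randint(2, (20,)))

    fig, (ax1, ax2) = plt.subplots(1, 2)
    metric.plot(ax=ax1)                  # draw into an existing axis
    metric.plot(ax=ax2, add_text=False)  # same matrix without cell annotations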
diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py
index 8d50b2b70f4..1496af71418 100644
--- a/src/torchmetrics/classification/confusion_matrix.py
+++ b/src/torchmetrics/classification/confusion_matrix.py
@@ -37,7 +37,7 @@
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
-from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix

 if not _MATPLOTLIB_AVAILABLE:
     __doctest_skip__ = [
@@ -131,13 +131,18 @@ def compute(self) -> Tensor:
         return _binary_confusion_matrix_compute(self.confmat, self.normalize)

     def plot(
-        self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
+        self,
+        val: Optional[Tensor] = None,
+        ax: Optional[_AX_TYPE] = None,
+        add_text: bool = True,
+        labels: Optional[List[str]] = None,
     ) -> _PLOT_OUT_TYPE:
         """Plot a single or multiple values from the metric.

         Args:
             val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                 If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add the plot to that axis
             add_text: if the value of each cell should be added to the plot
             labels: a list of strings, if provided will be added to the plot to indicate the different classes

@@ -157,10 +162,10 @@ def plot(
             >>> metric.update(randint(5, (20,)), randint(5, (20,)))
             >>> fig_, ax_ = metric.plot()
         """
-        val = val or self.compute()
+        val = val if val is not None else self.compute()
         if not isinstance(val, Tensor):
             raise TypeError(f"Expected val to be a single tensor but got {val}")
-        fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
+        fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
         return fig, ax

@@ -269,13 +274,18 @@ def compute(self) -> Tensor:
         return _multiclass_confusion_matrix_compute(self.confmat, self.normalize)

     def plot(
-        self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
+        self,
+        val: Optional[Tensor] = None,
+        ax: Optional[_AX_TYPE] = None,
+        add_text: bool = True,
+        labels: Optional[List[str]] = None,
     ) -> _PLOT_OUT_TYPE:
         """Plot a single or multiple values from the metric.

         Args:
             val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                 If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add the plot to that axis
             add_text: if the value of each cell should be added to the plot
             labels: a list of strings, if provided will be added to the plot to indicate the different classes

@@ -295,10 +305,10 @@ def plot(
             >>> metric.update(randint(5, (20,)), randint(5, (20,)))
             >>> fig_, ax_ = metric.plot()
         """
-        val = val or self.compute()
+        val = val if val is not None else self.compute()
         if not isinstance(val, Tensor):
             raise TypeError(f"Expected val to be a single tensor but got {val}")
-        fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
+        fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
         return fig, ax

@@ -393,13 +403,18 @@ def compute(self) -> Tensor:
         return _multilabel_confusion_matrix_compute(self.confmat, self.normalize)

     def plot(
-        self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
+        self,
+        val: Optional[Tensor] = None,
+        ax: Optional[_AX_TYPE] = None,
+        add_text: bool = True,
+        labels: Optional[List[str]] = None,
     ) -> _PLOT_OUT_TYPE:
         """Plot a single or multiple values from the metric.

         Args:
             val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                 If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add the plot to that axis
             add_text: if the value of each cell should be added to the plot
             labels: a list of strings, if provided will be added to the plot to indicate the different classes

@@ -419,10 +434,10 @@ def plot(
             >>> metric.update(randint(5, (20,)), randint(5, (20,)))
             >>> fig_, ax_ = metric.plot()
         """
-        val = val or self.compute()
+        val = val if val is not None else self.compute()
         if not isinstance(val, Tensor):
             raise TypeError(f"Expected val to be a single tensor but got {val}")
-        fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
+        fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
         return fig, ax
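A note on the repeated change from ``val = val or self.compute()`` to an explicit ``None`` check: truthiness is the
wrong test for tensors, as this plain-PyTorch sketch shows:

.. code-block:: python

    import torch

    val = torch.zeros(3, 3)
    try:
        val = val or None  # truthiness of a multi-element tensor is ambiguous
    except RuntimeError as err:
        print(err)  # "Boolean value of Tensor with more than one element is ambiguous"

    scalar = torch.tensor(0.0)
    print(bool(scalar))  # False: a legitimate all-zero result would wrongly trigger the fallback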
diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index 6ed50708e28..74a1d6b2863 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -23,6 +23,11 @@
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_warn
 from torchmetrics.utilities.data import _flatten_dict, allclose
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["MetricCollection.plot", "MetricCollection.plot_all"]

 class MetricCollection(ModuleDict):
@@ -482,3 +487,90 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection":
         for _, m in self.items(keep_base=True, copy_state=False):
             m.set_dtype(dst_type)
         return self
+
+    def plot(
+        self,
+        val: Optional[Union[Dict, Sequence[Dict]]] = None,
+        ax: Optional[Union[_AX_TYPE, Sequence[_AX_TYPE]]] = None,
+        together: bool = False,
+    ) -> Sequence[_PLOT_OUT_TYPE]:
+        """Plot a single or multiple values from the metric.
+
+        The plot method has two modes of operation. If argument `together` is set to `False` (default), the `.plot`
+        method of each metric will be called individually and the result will be a list of figures. If `together` is
+        set to `True`, the values of all metrics will instead be plotted in the same figure.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: Either a single instance of matplotlib axis object or a sequence of matplotlib axis objects. If
+                provided, will add the plots to the provided axis objects. If not provided, will create new ones. If
+                argument `together` is set to `True`, a single object is expected. If `together` is set to `False`,
+                the number of axis objects needs to be the same length as the number of metrics in the collection.
+            together: If `True`, will plot all metrics in the same axis. If `False`, will plot each metric in its
+                own axis.
+
+        Returns:
+            Either a single tuple of Figure and Axes object or a sequence of such tuples, one for each metric in the
+            collection.
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+            ValueError:
+                If `together` is not a bool
+            ValueError:
+                If `ax` is not an instance of matplotlib axis object or a sequence of matplotlib axis objects
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics import MetricCollection
+            >>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
+            >>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
+            >>> metrics.update(torch.rand(10), torch.randint(2, (10,)))
+            >>> fig_ax_ = metrics.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics import MetricCollection
+            >>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
+            >>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metrics(torch.rand(10), torch.randint(2, (10,))))
+            >>> fig_, ax_ = metrics.plot(values, together=True)
+        """
+        if not isinstance(together, bool):
+            raise ValueError(f"Expected argument `together` to be a boolean, but got {type(together)}")
+        if ax is not None:
+            if together and not isinstance(ax, _AX_TYPE):
+                raise ValueError(
+                    f"Expected argument `ax` to be a matplotlib axis object, but got {type(ax)} when `together=True`"
+                )
+            if not together and (
+                not isinstance(ax, Sequence)
+                or not all(isinstance(a, _AX_TYPE) for a in ax)
+                or len(ax) != len(self)
+            ):
+                raise ValueError(
+                    f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
+                    f"number of metrics in the collection, but got {type(ax)} when `together=False`"
+                )
+
+        val = val if val is not None else self.compute()
+        if together:
+            return plot_single_or_multi_val(val, ax=ax)
+        fig_axs = []
+        for i, (k, m) in enumerate(self.items(keep_base=True, copy_state=False)):
+            if isinstance(val, dict):
+                f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
+            elif isinstance(val, Sequence):
+                f, a = m.plot([v[k] for v in val], ax=ax[i] if ax is not None else ax)
+            fig_axs.append((f, a))
+        return fig_axs
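Since ``together=False`` returns one ``(fig, ax)`` pair per member, iterating over the return value is the natural way
to post-process the individual figures. A sketch reusing the collection from the documentation examples above; the
file names are illustrative:

.. code-block:: python

    fig_axs = collection.plot(val=values)
    for (fig, ax), name in zip(fig_axs, collection.keys()):
        fig.savefig(f"{name}.png")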
diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py
index fb40df349f1..5d4841f7a20 100644
--- a/src/torchmetrics/utilities/plot.py
+++ b/src/torchmetrics/utilities/plot.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from itertools import product
 from math import ceil, floor, sqrt
-from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union, no_type_check

 import numpy as np
 import torch
@@ -100,7 +100,13 @@ def plot_single_or_multi_val(
             ax.plot(i, v.detach().cpu(), marker="o", markersize=10, linestyle="None", label=label)
     elif isinstance(val, dict):
         for i, (k, v) in enumerate(val.items()):
-            ax.plot(i, v.detach().cpu(), marker="o", markersize=10, label=k)
+            if v.numel() != 1:
+                ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=k)
+                ax.get_xaxis().set_visible(True)
+                ax.set_xlabel("Step")
+                ax.set_xticks(torch.arange(len(v)))
+            else:
+                ax.plot(i, v.detach().cpu(), marker="o", markersize=10, label=k)
     elif isinstance(val, Sequence):
         n_steps = len(val)
         if isinstance(val[0], dict):
@@ -167,7 +173,7 @@ def _get_col_row_split(n: int) -> Tuple[int, int]:
     return ceil(nsq), ceil(nsq)


-def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> np.ndarray:  # type: ignore[valid-type]
+def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> Union[np.ndarray, _AX_TYPE]:  # type: ignore[valid-type]
     """Reduce `axs` to `nb` Axes.

     All further Axes are removed from the figure.
@@ -182,8 +188,10 @@ def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> np.ndarray: # type: @style_change(_style) +@no_type_check def plot_confusion_matrix( confmat: Tensor, + ax: Optional[_AX_TYPE] = None, add_text: bool = True, labels: Optional[List[Union[int, str]]] = None, ) -> _PLOT_OUT_TYPE: @@ -195,6 +203,7 @@ def plot_confusion_matrix( Args: confmat: the confusion matrix. Either should be an [N,N] matrix in the binary and multiclass cases or an [N, 2, 2] matrix for multilabel classification + ax: Axis from a figure. If not provided, a new figure and axis will be created add_text: if text should be added to each cell with the given value labels: labels to add the x- and y-axis @@ -225,7 +234,7 @@ def plot_confusion_matrix( fig_label = None labels = labels or np.arange(n_classes).tolist() - fig, axs = plt.subplots(nrows=rows, ncols=cols) + fig, axs = plt.subplots(nrows=rows, ncols=cols) if ax is None else (ax.get_figure(), ax) axs = trim_axs(axs, nb) for i in range(nb): ax = axs[i] if rows != 1 and cols != 1 else axs diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index a5b1b3657d7..2fa7273f1bf 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -299,7 +299,7 @@ def plot( >>> fig_, ax_ = tracker.plot() # plot all epochs """ - val = val if val is not None else list(self.compute_all()) + val = val if val is not None else self.compute_all() fig, ax = plot_single_or_multi_val( val, ax=ax, diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 55940fe85d5..98072a6d907 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -21,6 +21,7 @@ import torch from torch import tensor +from torchmetrics import MetricCollection from torchmetrics.aggregation import MaxMetric, MeanMetric, MinMetric, SumMetric from torchmetrics.audio import ( ScaleInvariantSignalDistortionRatio, @@ -775,6 +776,36 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label assert cond1 or cond2 +@pytest.mark.parametrize("together", [True, False]) +@pytest.mark.parametrize("num_vals", [1, 2]) +def test_plot_method_collection(together, num_vals): + """Test the plot method of metric collection.""" + m_collection = MetricCollection( + BinaryAccuracy(), + BinaryPrecision(), + BinaryRecall(), + ) + if num_vals == 1: + m_collection.update(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,))) + fig_ax = m_collection.plot(together=together) + else: + vals = [] + for _ in range(num_vals): + vals.append(m_collection(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)))) + fig_ax = m_collection.plot(val=vals, together=together) + + if together: + assert isinstance(fig_ax, tuple) + assert len(fig_ax) == 2 + fig, ax = fig_ax + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes) + else: + assert isinstance(fig_ax, list) + assert all(isinstance(f[0], plt.Figure) for f in fig_ax) + assert all(isinstance(f[1], matplotlib.axes.Axes) for f in fig_ax) + + @pytest.mark.parametrize( ("metric_class", "preds", "target"), [