Skip to content

Commit

Permalink
Add plotting 15/n (#1638)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Mar 21, 2023
1 parent 9055b99 commit b613c50
Show file tree
Hide file tree
Showing 7 changed files with 543 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1621](https://github.com/Lightning-AI/metrics/pull/1621),
[#1624](https://github.com/Lightning-AI/metrics/pull/1624),
[#1623](https://github.com/Lightning-AI/metrics/pull/1623),
[#1638](https://github.com/Lightning-AI/metrics/pull/1638),
[#1631](https://github.com/Lightning-AI/metrics/pull/1631),
)

Expand Down
48 changes: 47 additions & 1 deletion src/torchmetrics/classification/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple, no_type_check
from typing import Any, Callable, Optional, Sequence, Tuple, Union, no_type_check

import torch
from torch import Tensor
Expand All @@ -21,6 +21,11 @@
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["Dice.plot"]


class Dice(Metric):
Expand Down Expand Up @@ -235,3 +240,44 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, _, fn = self._get_final_stats()
return _dice_compute(tp, fp, fn, self.average, self.mdmc_reduce, self.zero_division)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure object and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torch import randint
>>> from torchmetrics.classification import Dice
>>> metric = Dice()
>>> metric.update(randint(2,(10,)), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torch import randint
>>> from torchmetrics.classification import Dice
>>> metric = Dice()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(randint(2,(10,)), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
89 changes: 88 additions & 1 deletion src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor
Expand All @@ -33,6 +33,11 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTaskNoBinary
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MulticlassExactMatch.plot", "MultilabelExactMatch.plot"]


class MulticlassExactMatch(Metric):
Expand Down Expand Up @@ -140,6 +145,47 @@ def compute(self) -> Tensor:
correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
return _exact_match_reduce(correct, self.total)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure object and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassExactMatch
>>> metric = MulticlassExactMatch(num_classes=3)
>>> metric.update(randint(3, (20,5)), randint(3, (20,5)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> from torch import randint
>>> # Example plotting a multiple values per class
>>> from torchmetrics.classification import MulticlassExactMatch
>>> metric = MulticlassExactMatch(num_classes=3)
>>> values = []
>>> for _ in range(20):
... values.append(metric(randint(3, (20,5)), randint(3, (20,5))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class MultilabelExactMatch(Metric):
r"""Compute Exact match (also known as subset accuracy) for multilabel tasks.
Expand Down Expand Up @@ -261,6 +307,47 @@ def compute(self) -> Tensor:
correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
return _exact_match_reduce(correct, self.total)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelExactMatch
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric.update(randint(2, (20, 3, 5)), randint(2, (20, 3, 5)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelExactMatch
>>> metric = MultilabelExactMatch(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(randint(2, (20, 3, 5)), randint(2, (20, 3, 5))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class ExactMatch:
r"""Compute Exact match (also known as subset accuracy).
Expand Down
134 changes: 133 additions & 1 deletion src/torchmetrics/classification/hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Optional, Sequence, Union

from torch import Tensor
from typing_extensions import Literal
Expand All @@ -20,6 +20,15 @@
from torchmetrics.functional.classification.hamming import _hamming_distance_reduce
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = [
"BinaryHammingDistance.plot",
"MulticlassHammingDistance.plot",
"MultilabelHammingDistance.plot",
]


class BinaryHammingDistance(BinaryStatScores):
Expand Down Expand Up @@ -98,6 +107,47 @@ def compute(self) -> Tensor:
tp, fp, tn, fn = self._final_state()
return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure object and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHammingDistance
>>> metric = BinaryHammingDistance()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHammingDistance
>>> metric = BinaryHammingDistance()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class MulticlassHammingDistance(MulticlassStatScores):
r"""Compute the average `Hamming distance`_ (also known as Hamming loss) for multiclass tasks.
Expand Down Expand Up @@ -204,6 +254,47 @@ def compute(self) -> Tensor:
tp, fp, tn, fn = self._final_state()
return _hamming_distance_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure object and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> metric = MulticlassHammingDistance(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting a multiple values per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> metric = MulticlassHammingDistance(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
... values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class MultilabelHammingDistance(MultilabelStatScores):
r"""Compute the average `Hamming distance`_ (also known as Hamming loss) for multilabel tasks.
Expand Down Expand Up @@ -310,6 +401,47 @@ def compute(self) -> Tensor:
tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class HammingDistance:
r"""Compute the average `Hamming distance`_ (also known as Hamming loss).
Expand Down
Loading

0 comments on commit b613c50

Please sign in to comment.