Skip to content

Commit

Permalink
Convert classification wrapper to metrics (#1963)
Browse files Browse the repository at this point in the history
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
SkafteNicki and mergify[bot] authored Aug 3, 2023
1 parent ea17ffb commit f6b5890
Show file tree
Hide file tree
Showing 67 changed files with 620 additions and 87 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Changed all non-task specific classification metrics to be true subtypes of `Metric` ([#1963](https://github.com/Lightning-AI/torchmetrics/pull/1963))


### Removed
Expand Down
2 changes: 1 addition & 1 deletion docs/source/classification/accuracy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ ________________
.. autoclass:: torchmetrics.Accuracy
:noindex:
:exclude-members: update, compute

:special-members: __new__

BinaryAccuracy
^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/auroc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ________________
.. autoclass:: torchmetrics.AUROC
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryAUROC
^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/average_precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.AveragePrecision
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryAveragePrecision
^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/calibration_error.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ________________
.. autoclass:: torchmetrics.CalibrationError
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryCalibrationError
^^^^^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/cohen_kappa.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ Cohen Kappa
Module Interface
________________

CohenKappa
^^^^^^^^^^

.. autoclass:: torchmetrics.CohenKappa
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryCohenKappa
^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/confusion_matrix.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ Confusion Matrix
Module Interface
________________

ConfusionMatrix
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.ConfusionMatrix
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryConfusionMatrix
^^^^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/exact_match.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ Exact Match
Module Interface
________________

ExactMatch
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.ExactMatch
:noindex:
:exclude-members: update, compute
:special-members: __new__

MulticlassExactMatch
^^^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/f1_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ F-1 Score
Module Interface
________________

F1Score
^^^^^^^

.. autoclass:: torchmetrics.F1Score
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryF1Score
^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/fbeta_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ F-Beta Score
Module Interface
________________

FBetaScore
^^^^^^^^^^

.. autoclass:: torchmetrics.FBetaScore
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryFBetaScore
^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/hamming_distance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ Hamming Distance
Module Interface
________________

HammingDistance
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.HammingDistance
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryHammingDistance
^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/hinge_loss.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.HingeLoss
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryHingeLoss
^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/jaccard_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ Jaccard Index
Module Interface
________________

JaccardIndex
^^^^^^^^^^^^

.. autoclass:: torchmetrics.JaccardIndex
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryJaccardIndex
^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/matthews_corr_coef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ Matthews Correlation Coefficient
Module Interface
________________

MatthewsCorrCoef
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MatthewsCorrCoef
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryMatthewsCorrCoef
^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ________________
.. autoclass:: torchmetrics.Precision
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryPrecision
^^^^^^^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions docs/source/classification/precision_at_fixed_recall.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Precision At Fixed Recall
Module Interface
________________

.. autoclass:: torchmetrics.PrecisionAtFixedRecall
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryPrecisionAtFixedRecall
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/precision_recall_curve.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.PrecisionRecallCurve
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/recall.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.Recall
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryRecall
^^^^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions docs/source/classification/recall_at_fixed_precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Recall At Fixed Precision
Module Interface
________________

.. autoclass:: torchmetrics.RecallAtFixedPrecision
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryRecallAtFixedPrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/roc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.ROC
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryROC
^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/specificity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.Specificity
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinarySpecificity
^^^^^^^^^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions docs/source/classification/specificity_at_sensitivity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Specificity At Sensitivity
Module Interface
________________

.. autoclass:: torchmetrics.SpecificityAtSensitivity
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinarySpecificityAtSensitivity
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/stat_scores.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ Stat Scores
Module Interface
________________

StatScores
^^^^^^^^^^

.. autoclass:: torchmetrics.StatScores
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryStatScores
^^^^^^^^^^^^^^^^
Expand Down
7 changes: 1 addition & 6 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@
import shutil
import sys

import torch

# this removes "Initializes internal Module state, shared by both nn.Module and ScriptModule." from the docs
torch.nn.Module.__init__.__doc__ = ""

import pt_lightning_sphinx_theme
from lightning_utilities.docs import fetch_external_assets
from lightning_utilities.docs.formatting import _transform_changelog
Expand Down Expand Up @@ -391,7 +386,7 @@ def _get_version_str():

autodoc_member_order = "groupwise"

autoclass_content = "both"
autoclass_content = "class"

autodoc_default_options = {
"members": True,
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
Recall,
RecallAtFixedPrecision,
Specificity,
SpecificityAtSensitivity,
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
Expand Down Expand Up @@ -223,6 +224,7 @@
"SignalNoiseRatio",
"SpearmanCorrCoef",
"Specificity",
"SpecificityAtSensitivity",
"SpectralAngleMapper",
"SpectralDistortionIndex",
"SQuAD",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
BinarySpecificityAtSensitivity,
MulticlassSpecificityAtSensitivity,
MultilabelSpecificityAtSensitivity,
SpecificityAtSensitivity,
)
from torchmetrics.classification.stat_scores import (
BinaryStatScores,
Expand Down Expand Up @@ -201,6 +202,7 @@
"MulticlassSpecificityAtSensitivity",
"MultilabelSpecificityAtSensitivity",
"BinaryPrecisionAtFixedRecall",
"SpecificityAtSensitivity",
"MulticlassPrecisionAtFixedRecall",
"MultilabelPrecisionAtFixedRecall",
"PrecisionAtFixedRecall",
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
from torchmetrics.functional.classification.accuracy import _accuracy_reduce
from torchmetrics.metric import Metric
Expand Down Expand Up @@ -445,7 +446,7 @@ def plot(
return self._plot(val, ax)


class Accuracy:
class Accuracy(_ClassificationTaskWrapper):
r"""Compute `Accuracy`_.
.. math::
Expand Down
15 changes: 14 additions & 1 deletion src/torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.precision_recall_curve import (
BinaryPrecisionRecallCurve,
MulticlassPrecisionRecallCurve,
Expand Down Expand Up @@ -460,7 +461,7 @@ def plot( # type: ignore[override]
return self._plot(val, ax)


class AUROC:
class AUROC(_ClassificationTaskWrapper):
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_).
The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for
Expand Down Expand Up @@ -518,3 +519,15 @@ def __new__( # type: ignore[misc]
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelAUROC(num_labels, average, **kwargs)
raise ValueError(f"Task {task} not supported!")

def update(self, *args: Any, **kwargs: Any) -> None:
"""Update metric state."""
raise NotImplementedError(
f"{self.__class__.__name__} metric does not have a global `update` method. Use the task specific metric."
)

def compute(self) -> None:
"""Compute metric."""
raise NotImplementedError(
f"{self.__class__.__name__} metric does not have a global `compute` method. Use the task specific metric."
)
3 changes: 2 additions & 1 deletion src/torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.precision_recall_curve import (
BinaryPrecisionRecallCurve,
MulticlassPrecisionRecallCurve,
Expand Down Expand Up @@ -465,7 +466,7 @@ def plot( # type: ignore[override]
return self._plot(val, ax)


class AveragePrecision:
class AveragePrecision(_ClassificationTaskWrapper):
r"""Compute the average precision (AP) score.
The AP score summarizes a precision-recall curve as an weighted mean of precisions at each threshold, with the
Expand Down
Loading

0 comments on commit f6b5890

Please sign in to comment.