Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert classification wrapper to metrics #1963

Merged
merged 11 commits into from
Aug 3, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Changed all non-task specific classification metrics to be true subtypes of `Metric` ([#1963](https://github.com/Lightning-AI/torchmetrics/pull/1963))


### Removed
Expand Down
2 changes: 1 addition & 1 deletion docs/source/classification/accuracy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ ________________
.. autoclass:: torchmetrics.Accuracy
:noindex:
:exclude-members: update, compute

:special-members: __new__

BinaryAccuracy
^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/auroc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ________________
.. autoclass:: torchmetrics.AUROC
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryAUROC
^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/average_precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.AveragePrecision
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryAveragePrecision
^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/calibration_error.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ________________
.. autoclass:: torchmetrics.CalibrationError
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryCalibrationError
^^^^^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/cohen_kappa.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ Cohen Kappa
Module Interface
________________

CohenKappa
^^^^^^^^^^

.. autoclass:: torchmetrics.CohenKappa
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryCohenKappa
^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/confusion_matrix.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ Confusion Matrix
Module Interface
________________

ConfusionMatrix
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.ConfusionMatrix
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryConfusionMatrix
^^^^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/exact_match.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ Exact Match
Module Interface
________________

ExactMatch
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.ExactMatch
:noindex:
:exclude-members: update, compute
:special-members: __new__

MulticlassExactMatch
^^^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/f1_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ F-1 Score
Module Interface
________________

F1Score
^^^^^^^

.. autoclass:: torchmetrics.F1Score
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryF1Score
^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/fbeta_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ F-Beta Score
Module Interface
________________

FBetaScore
^^^^^^^^^^

.. autoclass:: torchmetrics.FBetaScore
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryFBetaScore
^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/hamming_distance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ Hamming Distance
Module Interface
________________

HammingDistance
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.HammingDistance
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryHammingDistance
^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/hinge_loss.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.HingeLoss
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryHingeLoss
^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/jaccard_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ Jaccard Index
Module Interface
________________

JaccardIndex
^^^^^^^^^^^^

.. autoclass:: torchmetrics.JaccardIndex
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryJaccardIndex
^^^^^^^^^^^^^^^^^^
Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/matthews_corr_coef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ Matthews Correlation Coefficient
Module Interface
________________

MatthewsCorrCoef
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MatthewsCorrCoef
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryMatthewsCorrCoef
^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ________________
.. autoclass:: torchmetrics.Precision
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryPrecision
^^^^^^^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions docs/source/classification/precision_at_fixed_recall.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Precision At Fixed Recall
Module Interface
________________

.. autoclass:: torchmetrics.PrecisionAtFixedRecall
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryPrecisionAtFixedRecall
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/precision_recall_curve.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.PrecisionRecallCurve
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/recall.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.Recall
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryRecall
^^^^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions docs/source/classification/recall_at_fixed_precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Recall At Fixed Precision
Module Interface
________________

.. autoclass:: torchmetrics.RecallAtFixedPrecision
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryRecallAtFixedPrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/roc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.ROC
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryROC
^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/source/classification/specificity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ________________
.. autoclass:: torchmetrics.Specificity
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinarySpecificity
^^^^^^^^^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions docs/source/classification/specificity_at_sensitivity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Specificity At Sensitivity
Module Interface
________________

.. autoclass:: torchmetrics.SpecificityAtSensitivity
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinarySpecificityAtSensitivity
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
4 changes: 1 addition & 3 deletions docs/source/classification/stat_scores.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ Stat Scores
Module Interface
________________

StatScores
^^^^^^^^^^

.. autoclass:: torchmetrics.StatScores
:noindex:
:exclude-members: update, compute
:special-members: __new__

BinaryStatScores
^^^^^^^^^^^^^^^^
Expand Down
7 changes: 1 addition & 6 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@
import shutil
import sys

import torch

# this removes "Initializes internal Module state, shared by both nn.Module and ScriptModule." from the docs
torch.nn.Module.__init__.__doc__ = ""

import pt_lightning_sphinx_theme
from lightning_utilities.docs import fetch_external_assets
from lightning_utilities.docs.formatting import _transform_changelog
Expand Down Expand Up @@ -391,7 +386,7 @@ def _get_version_str():

autodoc_member_order = "groupwise"

autoclass_content = "both"
autoclass_content = "class"

autodoc_default_options = {
"members": True,
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
Recall,
RecallAtFixedPrecision,
Specificity,
SpecificityAtSensitivity,
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
Expand Down Expand Up @@ -223,6 +224,7 @@
"SignalNoiseRatio",
"SpearmanCorrCoef",
"Specificity",
"SpecificityAtSensitivity",
"SpectralAngleMapper",
"SpectralDistortionIndex",
"SQuAD",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
BinarySpecificityAtSensitivity,
MulticlassSpecificityAtSensitivity,
MultilabelSpecificityAtSensitivity,
SpecificityAtSensitivity,
)
from torchmetrics.classification.stat_scores import (
BinaryStatScores,
Expand Down Expand Up @@ -201,6 +202,7 @@
"MulticlassSpecificityAtSensitivity",
"MultilabelSpecificityAtSensitivity",
"BinaryPrecisionAtFixedRecall",
"SpecificityAtSensitivity",
"MulticlassPrecisionAtFixedRecall",
"MultilabelPrecisionAtFixedRecall",
"PrecisionAtFixedRecall",
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
from torchmetrics.functional.classification.accuracy import _accuracy_reduce
from torchmetrics.metric import Metric
Expand Down Expand Up @@ -445,7 +446,7 @@ def plot(
return self._plot(val, ax)


class Accuracy:
class Accuracy(_ClassificationTaskWrapper):
r"""Compute `Accuracy`_.

.. math::
Expand Down
15 changes: 14 additions & 1 deletion src/torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.precision_recall_curve import (
BinaryPrecisionRecallCurve,
MulticlassPrecisionRecallCurve,
Expand Down Expand Up @@ -460,7 +461,7 @@ def plot( # type: ignore[override]
return self._plot(val, ax)


class AUROC:
class AUROC(_ClassificationTaskWrapper):
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_).

The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for
Expand Down Expand Up @@ -518,3 +519,15 @@ def __new__( # type: ignore[misc]
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelAUROC(num_labels, average, **kwargs)
raise ValueError(f"Task {task} not supported!")

def update(self, *args: Any, **kwargs: Any) -> None:
"""Update metric state."""
raise NotImplementedError(
f"{self.__class__.__name__} metric does not have a global `update` method. Use the task specific metric."
)

def compute(self) -> None:
"""Compute metric."""
raise NotImplementedError(
f"{self.__class__.__name__} metric does not have a global `compute` method. Use the task specific metric."
)
3 changes: 2 additions & 1 deletion src/torchmetrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.precision_recall_curve import (
BinaryPrecisionRecallCurve,
MulticlassPrecisionRecallCurve,
Expand Down Expand Up @@ -465,7 +466,7 @@ def plot( # type: ignore[override]
return self._plot(val, ax)


class AveragePrecision:
class AveragePrecision(_ClassificationTaskWrapper):
r"""Compute the average precision (AP) score.

The AP score summarizes a precision-recall curve as an weighted mean of precisions at each threshold, with the
Expand Down
Loading
Loading