From 325eb16cf06bc7f59474d05e82621cae1095729c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 3 Aug 2023 03:49:50 -1000 Subject: [PATCH] Convert classification wrapper to metrics (#1963) Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> (cherry picked from commit f6b5890a4f00980f43f8cd4fb62bf8327b97bbb2) --- CHANGELOG.md | 3 +- docs/source/classification/accuracy.rst | 2 +- docs/source/classification/auroc.rst | 1 + .../classification/average_precision.rst | 1 + .../classification/calibration_error.rst | 1 + docs/source/classification/cohen_kappa.rst | 4 +-- .../classification/confusion_matrix.rst | 4 +-- docs/source/classification/exact_match.rst | 4 +-- docs/source/classification/f1_score.rst | 4 +-- docs/source/classification/fbeta_score.rst | 4 +-- .../classification/hamming_distance.rst | 4 +-- docs/source/classification/hinge_loss.rst | 1 + docs/source/classification/jaccard_index.rst | 4 +-- .../classification/matthews_corr_coef.rst | 4 +-- docs/source/classification/precision.rst | 1 + .../precision_at_fixed_recall.rst | 5 +++ .../classification/precision_recall_curve.rst | 1 + docs/source/classification/recall.rst | 1 + .../recall_at_fixed_precision.rst | 5 +++ docs/source/classification/roc.rst | 1 + docs/source/classification/specificity.rst | 1 + .../specificity_at_sensitivity.rst | 5 +++ docs/source/classification/stat_scores.rst | 4 +-- docs/source/conf.py | 7 +--- src/torchmetrics/__init__.py | 2 ++ src/torchmetrics/classification/__init__.py | 2 ++ src/torchmetrics/classification/accuracy.py | 3 +- src/torchmetrics/classification/auroc.py | 15 ++++++++- .../classification/average_precision.py | 3 +- src/torchmetrics/classification/base.py | 32 +++++++++++++++++++ .../classification/calibration_error.py | 3 +- .../classification/cohen_kappa.py | 3 +- .../classification/confusion_matrix.py | 3 +- .../classification/exact_match.py | 5 +-- src/torchmetrics/classification/f_beta.py | 12 +++---- src/torchmetrics/classification/hamming.py | 3 +- src/torchmetrics/classification/hinge.py | 3 +- src/torchmetrics/classification/jaccard.py | 3 +- .../classification/matthews_corrcoef.py | 3 +- .../classification/precision_fixed_recall.py | 5 +-- .../classification/precision_recall.py | 7 ++-- .../classification/precision_recall_curve.py | 3 +- .../classification/recall_fixed_precision.py | 5 +-- src/torchmetrics/classification/roc.py | 5 +-- .../classification/specificity.py | 3 +- .../classification/specificity_sensitivity.py | 5 +-- .../classification/stat_scores.py | 5 +-- .../unittests/classification/test_accuracy.py | 32 +++++++++++++------ tests/unittests/classification/test_auroc.py | 24 +++++++++++++- .../classification/test_average_precision.py | 23 +++++++++++++ .../classification/test_calibration_error.py | 27 +++++++++++++++- .../classification/test_cohen_kappa.py | 23 ++++++++++++- .../classification/test_confusion_matrix.py | 23 +++++++++++++ .../classification/test_exact_match.py | 23 ++++++++++++- tests/unittests/classification/test_f_beta.py | 28 ++++++++++++++++ .../classification/test_hamming_distance.py | 23 +++++++++++++ tests/unittests/classification/test_hinge.py | 23 ++++++++++++- .../unittests/classification/test_jaccard.py | 29 ++++++++++++++++- .../classification/test_matthews_corrcoef.py | 23 +++++++++++++ .../test_precision_fixed_recall.py | 23 +++++++++++++ .../classification/test_precision_recall.py | 28 ++++++++++++++++ .../test_precision_recall_curve.py | 23 +++++++++++++ .../test_recall_fixed_precision.py | 
23 +++++++++++++ tests/unittests/classification/test_roc.py | 24 +++++++++++++- .../classification/test_specificity.py | 29 ++++++++++++++++- .../test_specificity_sensitivity.py | 23 +++++++++++++ .../classification/test_stat_scores.py | 29 ++++++++++++++++- 67 files changed, 620 insertions(+), 88 deletions(-) create mode 100644 src/torchmetrics/classification/base.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 449d20c9052..91f4853e3b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,8 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- - +- Changed all non-task specific classification metrics to be true subtypes of `Metric` ([#1963](https://github.com/Lightning-AI/torchmetrics/pull/1963)) ### Fixed diff --git a/docs/source/classification/accuracy.rst b/docs/source/classification/accuracy.rst index 075dfd1eb89..33fc519d57a 100644 --- a/docs/source/classification/accuracy.rst +++ b/docs/source/classification/accuracy.rst @@ -13,7 +13,7 @@ ________________ .. autoclass:: torchmetrics.Accuracy :noindex: :exclude-members: update, compute - + :special-members: __new__ BinaryAccuracy ^^^^^^^^^^^^^^ diff --git a/docs/source/classification/auroc.rst b/docs/source/classification/auroc.rst index 8c3a3d91ea6..619d239ef02 100644 --- a/docs/source/classification/auroc.rst +++ b/docs/source/classification/auroc.rst @@ -15,6 +15,7 @@ ________________ .. autoclass:: torchmetrics.AUROC :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryAUROC ^^^^^^^^^^^ diff --git a/docs/source/classification/average_precision.rst b/docs/source/classification/average_precision.rst index fb3466d421a..e8ed3dc61ec 100644 --- a/docs/source/classification/average_precision.rst +++ b/docs/source/classification/average_precision.rst @@ -13,6 +13,7 @@ ________________ .. autoclass:: torchmetrics.AveragePrecision :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryAveragePrecision ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/calibration_error.rst b/docs/source/classification/calibration_error.rst index 7ef898cbc4c..df47612b525 100644 --- a/docs/source/classification/calibration_error.rst +++ b/docs/source/classification/calibration_error.rst @@ -15,6 +15,7 @@ ________________ .. autoclass:: torchmetrics.CalibrationError :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryCalibrationError ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/cohen_kappa.rst b/docs/source/classification/cohen_kappa.rst index c3c7cad795f..09501a35a2a 100644 --- a/docs/source/classification/cohen_kappa.rst +++ b/docs/source/classification/cohen_kappa.rst @@ -12,12 +12,10 @@ Cohen Kappa Module Interface ________________ -CohenKappa -^^^^^^^^^^ - .. autoclass:: torchmetrics.CohenKappa :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryCohenKappa ^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/confusion_matrix.rst b/docs/source/classification/confusion_matrix.rst index 2185e83d6ac..94a5baee9e2 100644 --- a/docs/source/classification/confusion_matrix.rst +++ b/docs/source/classification/confusion_matrix.rst @@ -12,12 +12,10 @@ Confusion Matrix Module Interface ________________ -ConfusionMatrix -^^^^^^^^^^^^^^^ - .. 
autoclass:: torchmetrics.ConfusionMatrix :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryConfusionMatrix ^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/exact_match.rst b/docs/source/classification/exact_match.rst index d6cd67e5fbc..d1f7f313d35 100644 --- a/docs/source/classification/exact_match.rst +++ b/docs/source/classification/exact_match.rst @@ -10,12 +10,10 @@ Exact Match Module Interface ________________ -ExactMatch -^^^^^^^^^^^^^^^ - .. autoclass:: torchmetrics.ExactMatch :noindex: :exclude-members: update, compute + :special-members: __new__ MulticlassExactMatch ^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/f1_score.rst b/docs/source/classification/f1_score.rst index 7fa20cefdf2..64b5dd22cc4 100644 --- a/docs/source/classification/f1_score.rst +++ b/docs/source/classification/f1_score.rst @@ -10,12 +10,10 @@ F-1 Score Module Interface ________________ -F1Score -^^^^^^^ - .. autoclass:: torchmetrics.F1Score :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryF1Score ^^^^^^^^^^^^^ diff --git a/docs/source/classification/fbeta_score.rst b/docs/source/classification/fbeta_score.rst index 36fbaaf2899..22c99175fd5 100644 --- a/docs/source/classification/fbeta_score.rst +++ b/docs/source/classification/fbeta_score.rst @@ -12,12 +12,10 @@ F-Beta Score Module Interface ________________ -FBetaScore -^^^^^^^^^^ - .. autoclass:: torchmetrics.FBetaScore :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryFBetaScore ^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/hamming_distance.rst b/docs/source/classification/hamming_distance.rst index a775f8ca727..6559fa41308 100644 --- a/docs/source/classification/hamming_distance.rst +++ b/docs/source/classification/hamming_distance.rst @@ -10,12 +10,10 @@ Hamming Distance Module Interface ________________ -HammingDistance -^^^^^^^^^^^^^^^ - .. autoclass:: torchmetrics.HammingDistance :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryHammingDistance ^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/hinge_loss.rst b/docs/source/classification/hinge_loss.rst index 8976dda8348..1daa734cb75 100644 --- a/docs/source/classification/hinge_loss.rst +++ b/docs/source/classification/hinge_loss.rst @@ -13,6 +13,7 @@ ________________ .. autoclass:: torchmetrics.HingeLoss :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryHingeLoss ^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/jaccard_index.rst b/docs/source/classification/jaccard_index.rst index dd1217e6b6c..b0b312a38cb 100644 --- a/docs/source/classification/jaccard_index.rst +++ b/docs/source/classification/jaccard_index.rst @@ -10,12 +10,10 @@ Jaccard Index Module Interface ________________ -JaccardIndex -^^^^^^^^^^^^ - .. autoclass:: torchmetrics.JaccardIndex :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryJaccardIndex ^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/matthews_corr_coef.rst b/docs/source/classification/matthews_corr_coef.rst index 592d646231c..373b143bf19 100644 --- a/docs/source/classification/matthews_corr_coef.rst +++ b/docs/source/classification/matthews_corr_coef.rst @@ -12,12 +12,10 @@ Matthews Correlation Coefficient Module Interface ________________ -MatthewsCorrCoef -^^^^^^^^^^^^^^^^ - .. 
autoclass:: torchmetrics.MatthewsCorrCoef :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryMatthewsCorrCoef ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/precision.rst b/docs/source/classification/precision.rst index 4e83cd92932..b6c4668ba04 100644 --- a/docs/source/classification/precision.rst +++ b/docs/source/classification/precision.rst @@ -15,6 +15,7 @@ ________________ .. autoclass:: torchmetrics.Precision :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryPrecision ^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/precision_at_fixed_recall.rst b/docs/source/classification/precision_at_fixed_recall.rst index bd16e3d64e8..c91ec5ef078 100644 --- a/docs/source/classification/precision_at_fixed_recall.rst +++ b/docs/source/classification/precision_at_fixed_recall.rst @@ -10,6 +10,11 @@ Precision At Fixed Recall Module Interface ________________ +.. autoclass:: torchmetrics.PrecisionAtFixedRecall + :noindex: + :exclude-members: update, compute + :special-members: __new__ + BinaryPrecisionAtFixedRecall ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/precision_recall_curve.rst b/docs/source/classification/precision_recall_curve.rst index fd69bf3d163..7dafab0c06c 100644 --- a/docs/source/classification/precision_recall_curve.rst +++ b/docs/source/classification/precision_recall_curve.rst @@ -13,6 +13,7 @@ ________________ .. autoclass:: torchmetrics.PrecisionRecallCurve :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryPrecisionRecallCurve ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/recall.rst b/docs/source/classification/recall.rst index 7bd4694ffd3..52a7c6912b7 100644 --- a/docs/source/classification/recall.rst +++ b/docs/source/classification/recall.rst @@ -13,6 +13,7 @@ ________________ .. autoclass:: torchmetrics.Recall :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryRecall ^^^^^^^^^^^^ diff --git a/docs/source/classification/recall_at_fixed_precision.rst b/docs/source/classification/recall_at_fixed_precision.rst index 7e287782ae4..0946cf55424 100644 --- a/docs/source/classification/recall_at_fixed_precision.rst +++ b/docs/source/classification/recall_at_fixed_precision.rst @@ -10,6 +10,11 @@ Recall At Fixed Precision Module Interface ________________ +.. autoclass:: torchmetrics.RecallAtFixedPrecision + :noindex: + :exclude-members: update, compute + :special-members: __new__ + BinaryRecallAtFixedPrecision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/roc.rst b/docs/source/classification/roc.rst index dbf24ee096b..36e0bdcc1a8 100644 --- a/docs/source/classification/roc.rst +++ b/docs/source/classification/roc.rst @@ -13,6 +13,7 @@ ________________ .. autoclass:: torchmetrics.ROC :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryROC ^^^^^^^^^ diff --git a/docs/source/classification/specificity.rst b/docs/source/classification/specificity.rst index d284f7d4e52..73a259a76c5 100644 --- a/docs/source/classification/specificity.rst +++ b/docs/source/classification/specificity.rst @@ -13,6 +13,7 @@ ________________ .. 
autoclass:: torchmetrics.Specificity :noindex: :exclude-members: update, compute + :special-members: __new__ BinarySpecificity ^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/specificity_at_sensitivity.rst b/docs/source/classification/specificity_at_sensitivity.rst index 7d1f5f3a26f..3fa0ec58666 100644 --- a/docs/source/classification/specificity_at_sensitivity.rst +++ b/docs/source/classification/specificity_at_sensitivity.rst @@ -10,6 +10,11 @@ Specificity At Sensitivity Module Interface ________________ +.. autoclass:: torchmetrics.SpecificityAtSensitivity + :noindex: + :exclude-members: update, compute + :special-members: __new__ + BinarySpecificityAtSensitivity ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/classification/stat_scores.rst b/docs/source/classification/stat_scores.rst index 01eeeaf3dcb..ea3e75ffb07 100644 --- a/docs/source/classification/stat_scores.rst +++ b/docs/source/classification/stat_scores.rst @@ -12,12 +12,10 @@ Stat Scores Module Interface ________________ -StatScores -^^^^^^^^^^ - .. autoclass:: torchmetrics.StatScores :noindex: :exclude-members: update, compute + :special-members: __new__ BinaryStatScores ^^^^^^^^^^^^^^^^ diff --git a/docs/source/conf.py b/docs/source/conf.py index 8c250067f97..5f4afd38615 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,11 +17,6 @@ import shutil import sys -import torch - -# this removes "Initializes internal Module state, shared by both nn.Module and ScriptModule." from the docs -torch.nn.Module.__init__.__doc__ = "" - import pt_lightning_sphinx_theme from lightning_utilities.docs import fetch_external_assets from lightning_utilities.docs.formatting import _transform_changelog @@ -391,7 +386,7 @@ def _get_version_str(): autodoc_member_order = "groupwise" -autoclass_content = "both" +autoclass_content = "class" autodoc_default_options = { "members": True, diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index df40ab3e9e2..6eeea86d232 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -52,6 +52,7 @@ Recall, RecallAtFixedPrecision, Specificity, + SpecificityAtSensitivity, StatScores, ) from torchmetrics.collections import MetricCollection # noqa: E402 @@ -221,6 +222,7 @@ "SignalNoiseRatio", "SpearmanCorrCoef", "Specificity", + "SpecificityAtSensitivity", "SpectralAngleMapper", "SpectralDistortionIndex", "SQuAD", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 48a76aa46e8..684f0f2ae9f 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -107,6 +107,7 @@ BinarySpecificityAtSensitivity, MulticlassSpecificityAtSensitivity, MultilabelSpecificityAtSensitivity, + SpecificityAtSensitivity, ) from torchmetrics.classification.stat_scores import ( BinaryStatScores, @@ -201,6 +202,7 @@ "MulticlassSpecificityAtSensitivity", "MultilabelSpecificityAtSensitivity", "BinaryPrecisionAtFixedRecall", + "SpecificityAtSensitivity", "MulticlassPrecisionAtFixedRecall", "MultilabelPrecisionAtFixedRecall", "PrecisionAtFixedRecall", diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 304db418dda..d69f91f9ef0 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from 
torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.accuracy import _accuracy_reduce from torchmetrics.metric import Metric @@ -439,7 +440,7 @@ def plot( return self._plot(val, ax) -class Accuracy: +class Accuracy(_ClassificationTaskWrapper): r"""Compute `Accuracy`_. .. math:: diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index cb51e2b4f92..36f21c9c158 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve, @@ -454,7 +455,7 @@ def plot( # type: ignore[override] return self._plot(val, ax) -class AUROC: +class AUROC(_ClassificationTaskWrapper): r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). The AUROC score summarizes the ROC curve into an single number that describes the performance of a model for @@ -511,3 +512,15 @@ def __new__( # type: ignore[misc] raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return MultilabelAUROC(num_labels, average, **kwargs) raise ValueError(f"Task {task} not supported!") + + def update(self, *args: Any, **kwargs: Any) -> None: + """Update metric state.""" + raise NotImplementedError( + f"{self.__class__.__name__} metric does not have a global `update` method. Use the task specific metric." + ) + + def compute(self) -> None: + """Compute metric.""" + raise NotImplementedError( + f"{self.__class__.__name__} metric does not have a global `compute` method. Use the task specific metric." + ) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index e23ff6a8fd0..efa27ef283f 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve, @@ -460,7 +461,7 @@ def plot( # type: ignore[override] return self._plot(val, ax) -class AveragePrecision: +class AveragePrecision(_ClassificationTaskWrapper): r"""Compute the average precision (AP) score. The AP score summarizes a precision-recall curve as an weighted mean of precisions at each threshold, with the diff --git a/src/torchmetrics/classification/base.py b/src/torchmetrics/classification/base.py new file mode 100644 index 00000000000..62d3eebcebd --- /dev/null +++ b/src/torchmetrics/classification/base.py @@ -0,0 +1,32 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from torchmetrics.metric import Metric + + +class _ClassificationTaskWrapper(Metric): + """Base class for wrapper metrics for classification tasks.""" + + def update(self, *args: Any, **kwargs: Any) -> None: + """Update metric state.""" + raise NotImplementedError( + f"{self.__class__.__name__} metric does not have a global `update` method. Use the task specific metric." + ) + + def compute(self) -> None: + """Compute metric.""" + raise NotImplementedError( + f"{self.__class__.__name__} metric does not have a global `compute` method. Use the task specific metric." + ) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index d20bee41751..d1546bdd3bc 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.functional.classification.calibration_error import ( _binary_calibration_error_arg_validation, _binary_calibration_error_tensor_validation, @@ -328,7 +329,7 @@ def plot( return self._plot(val, ax) -class CalibrationError: +class CalibrationError(_ClassificationTaskWrapper): r"""`Top-label Calibration Error`_. The expected calibration error can be used to quantify how well a given model is calibrated e.g. how well the diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index c5c97a7167d..236015f126b 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.confusion_matrix import BinaryConfusionMatrix, MulticlassConfusionMatrix from torchmetrics.functional.classification.cohen_kappa import ( _binary_cohen_kappa_arg_validation, @@ -279,7 +280,7 @@ def plot( return self._plot(val, ax) -class CohenKappa: +class CohenKappa(_ClassificationTaskWrapper): r"""Calculate `Cohen's kappa score`_ that measures inter-annotator agreement. .. math:: diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 8858880e037..8c2ddac1766 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.functional.classification.confusion_matrix import ( _binary_confusion_matrix_arg_validation, _binary_confusion_matrix_compute, @@ -420,7 +421,7 @@ def plot( return fig, ax -class ConfusionMatrix: +class ConfusionMatrix(_ClassificationTaskWrapper): r"""Compute the `confusion matrix`_. 
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index d6e1e2fd0c2..e3b311f2e8d 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.functional.classification.exact_match import ( _exact_match_reduce, _multiclass_exact_match_update, @@ -355,7 +356,7 @@ def plot( return self._plot(val, ax) -class ExactMatch: +class ExactMatch(_ClassificationTaskWrapper): r"""Compute Exact match (also known as subset accuracy). Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be @@ -405,4 +406,4 @@ def __new__( if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return MultilabelExactMatch(num_labels, threshold, **kwargs) - return None + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 52710cda19b..c324bc1bacd 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.f_beta import ( _binary_fbeta_score_arg_validation, @@ -994,7 +995,7 @@ def plot( return self._plot(val, ax) -class FBetaScore: +class FBetaScore(_ClassificationTaskWrapper): r"""Compute `F-score`_ metric. .. math:: @@ -1035,6 +1036,7 @@ def __new__( **kwargs: Any, ) -> Metric: """Initialize task metric.""" + task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} @@ -1051,12 +1053,10 @@ def __new__( if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) + raise ValueError(f"Task {task} not supported!") -class F1Score: +class F1Score(_ClassificationTaskWrapper): r"""Compute F-1 score. .. 
math:: @@ -1112,4 +1112,4 @@ def __new__( if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return MultilabelF1Score(num_labels, threshold, average, **kwargs) - return None + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index 3e2f01a4225..d84db174007 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.hamming import _hamming_distance_reduce from torchmetrics.metric import Metric @@ -451,7 +452,7 @@ def plot( return self._plot(val, ax) -class HammingDistance: +class HammingDistance(_ClassificationTaskWrapper): r"""Compute the average `Hamming distance`_ (also known as Hamming loss). .. math:: diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index b086cbe41f6..0762bdd3416 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.functional.classification.hinge import ( _binary_confusion_matrix_format, _binary_hinge_loss_arg_validation, @@ -309,7 +310,7 @@ def plot( return self._plot(val, ax) -class HingeLoss: +class HingeLoss(_ClassificationTaskWrapper): r"""Compute the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index f3856169a3a..28097c4419d 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, @@ -407,7 +408,7 @@ def plot( return self._plot(val, ax) -class JaccardIndex: +class JaccardIndex(_ClassificationTaskWrapper): r"""Calculate the Jaccard index for multilabel tasks. The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 31934c487df..1ef76b97525 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, @@ -360,7 +361,7 @@ def plot( return self._plot(val, ax) -class MatthewsCorrCoef: +class MatthewsCorrCoef(_ClassificationTaskWrapper): r"""Calculate `Matthews correlation coefficient`_ . 
This metric measures the general correlation or quality of a classification. diff --git a/src/torchmetrics/classification/precision_fixed_recall.py b/src/torchmetrics/classification/precision_fixed_recall.py index cc7b19bcc52..7d3059d37b8 100644 --- a/src/torchmetrics/classification/precision_fixed_recall.py +++ b/src/torchmetrics/classification/precision_fixed_recall.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve, @@ -459,7 +460,7 @@ def plot( # type: ignore[override] return self._plot(val, ax) -class PrecisionAtFixedRecall: +class PrecisionAtFixedRecall(_ClassificationTaskWrapper): r"""Compute the highest possible recall value given the minimum precision thresholds provided. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for @@ -498,4 +499,4 @@ def __new__( # type: ignore[misc] return MultilabelPrecisionAtFixedRecall( num_labels, min_recall, thresholds, ignore_index, validate_args, **kwargs ) - return None # type: ignore[return-value] + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 27610abd2fb..1110ca82ad6 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce from torchmetrics.metric import Metric @@ -864,7 +865,7 @@ def plot( return self._plot(val, ax) -class Precision: +class Precision(_ClassificationTaskWrapper): r"""Compute `Precision`_. .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} @@ -922,10 +923,10 @@ def __new__( if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return MultilabelPrecision(num_labels, threshold, average, **kwargs) - return None + raise ValueError(f"Task {task} not supported!") -class Recall: +class Recall(_ClassificationTaskWrapper): r"""Compute `Recall`_. .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index b3121336483..f65fa44a1de 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.functional.classification.auroc import _reduce_auroc from torchmetrics.functional.classification.precision_recall_curve import ( _adjust_threshold_arg, @@ -581,7 +582,7 @@ def plot( ) -class PrecisionRecallCurve: +class PrecisionRecallCurve(_ClassificationTaskWrapper): r"""Compute the precision-recall curve. 
The curve consist of multiple pairs of precision and recall values evaluated at different thresholds, such that the diff --git a/src/torchmetrics/classification/recall_fixed_precision.py b/src/torchmetrics/classification/recall_fixed_precision.py index 2b3a0264d8f..b29d28f27c5 100644 --- a/src/torchmetrics/classification/recall_fixed_precision.py +++ b/src/torchmetrics/classification/recall_fixed_precision.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve, @@ -454,7 +455,7 @@ def plot( # type: ignore[override] return self._plot(val, ax) -class RecallAtFixedPrecision: +class RecallAtFixedPrecision(_ClassificationTaskWrapper): r"""Compute the highest possible recall value given the minimum precision thresholds provided. This is done by first calculating the precision-recall curve for different thresholds and the find the recall for @@ -493,4 +494,4 @@ def __new__( # type: ignore[misc] return MultilabelRecallAtFixedPrecision( num_labels, min_precision, thresholds, ignore_index, validate_args, **kwargs ) - return None # type: ignore[return-value] + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index a61e3dee2b8..8b15697302a 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve, @@ -472,7 +473,7 @@ def plot( ) -class ROC: +class ROC(_ClassificationTaskWrapper): r"""Compute the Receiver Operating Characteristic (ROC). The curve consist of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at @@ -557,4 +558,4 @@ def __new__( if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return MultilabelROC(num_labels, **kwargs) - return None + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 283e82e001d..75c9b4517c5 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.specificity import _specificity_reduce from torchmetrics.metric import Metric @@ -429,7 +430,7 @@ def plot( return self._plot(val, ax) -class Specificity: +class Specificity(_ClassificationTaskWrapper): r"""Compute `Specificity`_. .. 
math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} diff --git a/src/torchmetrics/classification/specificity_sensitivity.py b/src/torchmetrics/classification/specificity_sensitivity.py index 93749ba81c4..1b70daab33d 100644 --- a/src/torchmetrics/classification/specificity_sensitivity.py +++ b/src/torchmetrics/classification/specificity_sensitivity.py @@ -16,6 +16,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve, @@ -320,7 +321,7 @@ def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] ) -class SpecificityAtSensitivity: +class SpecificityAtSensitivity(_ClassificationTaskWrapper): r"""Compute the higest possible specificity value given the minimum sensitivity thresholds provided. This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and the @@ -359,4 +360,4 @@ def __new__( # type: ignore[misc] return MultilabelSpecificityAtSensitivity( num_labels, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs ) - return None # type: ignore[return-value] + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index de7d63ee13a..f1ebab548b5 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.classification.base import _ClassificationTaskWrapper from torchmetrics.functional.classification.stat_scores import ( _binary_stat_scores_arg_validation, _binary_stat_scores_compute, @@ -464,7 +465,7 @@ def compute(self) -> Tensor: return _multilabel_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) -class StatScores: +class StatScores(_ClassificationTaskWrapper): r"""Compute the number of true positives, false positives, true negatives, false negatives and the support. 
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the @@ -517,4 +518,4 @@ def __new__( if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return MultilabelStatScores(num_labels, threshold, average, **kwargs) - return None + raise ValueError(f"Task {task} not supported!") diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 9189b897b74..7501ee2f4ae 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -26,6 +26,7 @@ multiclass_accuracy, multilabel_accuracy, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases @@ -66,16 +67,6 @@ def _sklearn_accuracy_binary(preds, target, ignore_index, multidim_average): return np.stack(res) -def test_accuracy_raises_invalid_task(): - """Tests accuracy task enum from Accuracy.""" - task = "NotValidTask" - ignore_index = None - multidim_average = "global" - - with pytest.raises(ValueError, match=r"Invalid *"): - Accuracy(threshold=THRESHOLD, task=task, ignore_index=ignore_index, multidim_average=multidim_average) - - def test_accuracy_functional_raises_invalid_task(): """Tests accuracy task enum from functional.accuracy.""" preds, target = _input_binary @@ -557,3 +548,24 @@ def test_corner_cases(): metric = MulticlassAccuracy(num_classes=3, average="macro", ignore_index=0) res = metric(preds, target) assert res == 1.0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryAccuracy, {"task": "binary"}), + (MulticlassAccuracy, {"task": "multiclass", "num_classes": 3}), + (MultilabelAccuracy, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=Accuracy): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index e51f564b02f..2093b1bdd95 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -19,9 +19,10 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import roc_auc_score as sk_roc_auc_score -from torchmetrics.classification.auroc import BinaryAUROC, MulticlassAUROC, MultilabelAUROC +from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc from torchmetrics.functional.classification.roc import binary_roc +from torchmetrics.metric import Metric from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -407,3 +408,24 @@ def test_corner_case_max_fpr(max_fpr): target = torch.tensor([1, 1, 1, 1]) metric = BinaryAUROC(max_fpr=max_fpr) assert metric(preds, target) == 0.0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryAUROC, {"task": "binary"}), + (MulticlassAUROC, {"task": 
"multiclass", "num_classes": 3}), + (MultilabelAUROC, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=AUROC): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index 91a06329c24..1a4c507014b 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -20,6 +20,7 @@ from scipy.special import softmax from sklearn.metrics import average_precision_score as sk_average_precision_score from torchmetrics.classification.average_precision import ( + AveragePrecision, BinaryAveragePrecision, MulticlassAveragePrecision, MultilabelAveragePrecision, @@ -30,6 +31,7 @@ multilabel_average_precision, ) from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve +from torchmetrics.metric import Metric from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -394,3 +396,24 @@ def test_valid_input_thresholds(metric, thresholds): with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryAveragePrecision, {"task": "binary"}), + (MulticlassAveragePrecision, {"task": "multiclass", "num_classes": 3}), + (MultilabelAveragePrecision, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=AveragePrecision): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index fbd5568c0ae..60cad273bbe 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -19,11 +19,16 @@ from netcal.metrics import ECE, MCE from scipy.special import expit as sigmoid from scipy.special import softmax -from torchmetrics.classification.calibration_error import BinaryCalibrationError, MulticlassCalibrationError +from torchmetrics.classification.calibration_error import ( + BinaryCalibrationError, + CalibrationError, + MulticlassCalibrationError, +) from torchmetrics.functional.classification.calibration_error import ( binary_calibration_error, multiclass_calibration_error, ) +from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_13 from unittests import NUM_CLASSES @@ -269,3 +274,23 @@ def test_corner_case_due_to_dtype(): assert np.allclose( ECE(100).measure(preds.numpy(), target.numpy()), binary_calibration_error(preds, target, n_bins=100) ), "The metric should be close to the netcal implementation" + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryCalibrationError, {"task": "binary"}), + 
(MulticlassCalibrationError, {"task": "multiclass", "num_classes": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=CalibrationError): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index b81a1d20d07..c39336a8a37 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -18,8 +18,9 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa -from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, MulticlassCohenKappa +from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases @@ -225,3 +226,23 @@ def test_multiclass_confusion_matrix_dtypes_gpu(self, inputs, dtype): metric_args={"num_classes": NUM_CLASSES}, dtype=dtype, ) + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryCohenKappa, {"task": "binary"}), + (MulticlassCohenKappa, {"task": "multiclass", "num_classes": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=CohenKappa): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 4914983a07f..78f201c9121 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -20,6 +20,7 @@ from sklearn.metrics import confusion_matrix as sk_confusion_matrix from torchmetrics.classification.confusion_matrix import ( BinaryConfusionMatrix, + ConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix, ) @@ -28,6 +29,7 @@ multiclass_confusion_matrix, multilabel_confusion_matrix, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -362,3 +364,24 @@ def test_warning_on_nan(): match=".* NaN values found in confusion matrix have been replaced with zeros.", ): multiclass_confusion_matrix(preds, target, num_classes=5, normalize="true") + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryConfusionMatrix, {"task": "binary"}), + (MulticlassConfusionMatrix, {"task": "multiclass", "num_classes": 3}), + (MultilabelConfusionMatrix, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=ConfusionMatrix): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with 
pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index e2ce18a61f3..d4e82c76436 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -17,8 +17,9 @@ import pytest import torch from scipy.special import expit as sigmoid -from torchmetrics.classification.exact_match import MulticlassExactMatch, MultilabelExactMatch +from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch from torchmetrics.functional.classification.exact_match import multiclass_exact_match, multilabel_exact_match +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _multiclass_cases, _multilabel_cases @@ -273,3 +274,23 @@ def test_multilabel_exact_match_half_gpu(self, inputs, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (MulticlassExactMatch, {"task": "multiclass", "num_classes": 3}), + (MultilabelExactMatch, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=ExactMatch): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 3a692d33c1e..b23ecfb8792 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -24,6 +24,8 @@ from torchmetrics.classification.f_beta import ( BinaryF1Score, BinaryFBetaScore, + F1Score, + FBetaScore, MulticlassF1Score, MulticlassFBetaScore, MultilabelF1Score, @@ -37,6 +39,7 @@ multilabel_f1_score, multilabel_fbeta_score, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -577,3 +580,28 @@ def test_corner_case(): f1_score = MulticlassF1Score(num_classes=i, average="macro") res = f1_score(preds, target) assert res == torch.tensor([0.77777779]) + + +@pytest.mark.parametrize( + ("metric", "kwargs", "base_metric"), + [ + (BinaryF1Score, {"task": "binary"}, F1Score), + (MulticlassF1Score, {"task": "multiclass", "num_classes": 3}, F1Score), + (MultilabelF1Score, {"task": "multilabel", "num_labels": 3}, F1Score), + (None, {"task": "not_valid_task"}, F1Score), + (BinaryFBetaScore, {"task": "binary", "beta": 2.0}, FBetaScore), + (MulticlassFBetaScore, {"task": "multiclass", "num_classes": 3, "beta": 2.0}, FBetaScore), + (MultilabelFBetaScore, {"task": "multilabel", "num_labels": 3, "beta": 2.0}, FBetaScore), + (None, {"task": "not_valid_task"}, FBetaScore), + ], +) +def test_wrapper_class(metric, kwargs, base_metric): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert 
isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 0a59a7902e6..8da05c4cd8b 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -21,6 +21,7 @@ from sklearn.metrics import hamming_loss as sk_hamming_loss from torchmetrics.classification.hamming import ( BinaryHammingDistance, + HammingDistance, MulticlassHammingDistance, MultilabelHammingDistance, ) @@ -29,6 +30,7 @@ multiclass_hamming_distance, multilabel_hamming_distance, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -507,3 +509,24 @@ def test_multilabel_hamming_distance_dtype_gpu(self, inputs, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryHammingDistance, {"task": "binary"}), + (MulticlassHammingDistance, {"task": "multiclass", "num_classes": 3}), + (MultilabelHammingDistance, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=HammingDistance): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index 5aec63aa69f..536cf4698fa 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -20,8 +20,9 @@ from scipy.special import softmax from sklearn.metrics import hinge_loss as sk_hinge from sklearn.preprocessing import OneHotEncoder -from torchmetrics.classification.hinge import BinaryHingeLoss, MulticlassHingeLoss +from torchmetrics.classification.hinge import BinaryHingeLoss, HingeLoss, MulticlassHingeLoss from torchmetrics.functional.classification.hinge import binary_hinge_loss, multiclass_hinge_loss +from torchmetrics.metric import Metric from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases @@ -227,3 +228,23 @@ def test_multiclass_hinge_loss_dtype_gpu(self, inputs, dtype): metric_args={"num_classes": NUM_CLASSES}, dtype=dtype, ) + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryHingeLoss, {"task": "binary"}), + (MulticlassHingeLoss, {"task": "multiclass", "num_classes": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=HingeLoss): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index c6133acb435..49083bab908 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -19,12 +19,18 @@ from scipy.special import expit as sigmoid from sklearn.metrics import 
confusion_matrix as sk_confusion_matrix from sklearn.metrics import jaccard_score as sk_jaccard_index -from torchmetrics.classification.jaccard import BinaryJaccardIndex, MulticlassJaccardIndex, MultilabelJaccardIndex +from torchmetrics.classification.jaccard import ( + BinaryJaccardIndex, + JaccardIndex, + MulticlassJaccardIndex, + MultilabelJaccardIndex, +) from torchmetrics.functional.classification.jaccard import ( binary_jaccard_index, multiclass_jaccard_index, multilabel_jaccard_index, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -359,3 +365,24 @@ def test_corner_case(): assert torch.allclose(res, torch.ones_like(res)) res = multiclass_jaccard_index(pred, target, num_classes=10, average="none") assert torch.allclose(res, out) + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryJaccardIndex, {"task": "binary"}), + (MulticlassJaccardIndex, {"task": "multiclass", "num_classes": 3}), + (MultilabelJaccardIndex, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=JaccardIndex): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 795fc503479..18cdc141358 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -20,6 +20,7 @@ from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef from torchmetrics.classification.matthews_corrcoef import ( BinaryMatthewsCorrCoef, + MatthewsCorrCoef, MulticlassMatthewsCorrCoef, MultilabelMatthewsCorrCoef, ) @@ -28,6 +29,7 @@ multiclass_matthews_corrcoef, multilabel_matthews_corrcoef, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -361,3 +363,24 @@ def test_corner_cases(metric_fn, preds, target, expected): """Test the corner cases of perfect classifiers or completely random classifiers that they work as expected.""" out = metric_fn(preds, target) assert out == expected + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryMatthewsCorrCoef, {"task": "binary"}), + (MulticlassMatthewsCorrCoef, {"task": "multiclass", "num_classes": 3}), + (MultilabelMatthewsCorrCoef, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=MatthewsCorrCoef): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index e3d9ee1bf1f..0d782bd14fe 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py 
@@ -24,12 +24,14 @@ BinaryPrecisionAtFixedRecall, MulticlassPrecisionAtFixedRecall, MultilabelPrecisionAtFixedRecall, + PrecisionAtFixedRecall, ) from torchmetrics.functional.classification.precision_fixed_recall import ( binary_precision_at_fixed_recall, multiclass_precision_at_fixed_recall, multilabel_precision_at_fixed_recall, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -420,3 +422,24 @@ def test_valid_input_thresholds(metric, thresholds): with pytest.warns(None) as record: metric(min_recall=0.5, thresholds=thresholds) assert len(record) == 0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryPrecisionAtFixedRecall, {"task": "binary", "min_recall": 0.5}), + (MulticlassPrecisionAtFixedRecall, {"task": "multiclass", "num_classes": 3, "min_recall": 0.5}), + (MultilabelPrecisionAtFixedRecall, {"task": "multilabel", "num_labels": 3, "min_recall": 0.5}), + (None, {"task": "not_valid_task", "min_recall": 0.5}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=PrecisionAtFixedRecall): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 5fda29f1e30..e4a2ddff658 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -28,6 +28,8 @@ MulticlassRecall, MultilabelPrecision, MultilabelRecall, + Precision, + Recall, ) from torchmetrics.functional.classification.precision_recall import ( binary_precision, @@ -37,6 +39,7 @@ multilabel_precision, multilabel_recall, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -585,3 +588,28 @@ def test_corner_case(): metric = MulticlassRecall(num_classes=3, average="macro", ignore_index=0) res = metric(preds, target) assert res == 1.0 + + +@pytest.mark.parametrize( + ("metric", "kwargs", "base_metric"), + [ + (BinaryPrecision, {"task": "binary"}, Precision), + (MulticlassPrecision, {"task": "multiclass", "num_classes": 3}, Precision), + (MultilabelPrecision, {"task": "multilabel", "num_labels": 3}, Precision), + (None, {"task": "not_valid_task"}, Precision), + (BinaryRecall, {"task": "binary"}, Recall), + (MulticlassRecall, {"task": "multiclass", "num_classes": 3}, Recall), + (MultilabelRecall, {"task": "multilabel", "num_labels": 3}, Recall), + (None, {"task": "not_valid_task"}, Recall), + ], +) +def test_wrapper_class(metric, kwargs, base_metric): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 78329fb7d7a..11d7915927e 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ 
b/tests/unittests/classification/test_precision_recall_curve.py @@ -25,12 +25,14 @@ BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve, MultilabelPrecisionRecallCurve, + PrecisionRecallCurve, ) from torchmetrics.functional.classification.precision_recall_curve import ( binary_precision_recall_curve, multiclass_precision_recall_curve, multilabel_precision_recall_curve, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -430,3 +432,24 @@ def test_empty_state_dict(metric, thresholds): """Test that metric have an empty state dict.""" m = metric(thresholds=thresholds) assert m.state_dict() == {}, "Metric state dict should be empty." + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryPrecisionRecallCurve, {"task": "binary"}), + (MulticlassPrecisionRecallCurve, {"task": "multiclass", "num_classes": 3}), + (MultilabelPrecisionRecallCurve, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=PrecisionRecallCurve): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index c5c51848d07..cf88af1aed8 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -24,12 +24,14 @@ BinaryRecallAtFixedPrecision, MulticlassRecallAtFixedPrecision, MultilabelRecallAtFixedPrecision, + RecallAtFixedPrecision, ) from torchmetrics.functional.classification.recall_fixed_precision import ( binary_recall_at_fixed_precision, multiclass_recall_at_fixed_precision, multilabel_recall_at_fixed_precision, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -420,3 +422,24 @@ def test_valid_input_thresholds(metric, thresholds): with pytest.warns(None) as record: metric(min_precision=0.5, thresholds=thresholds) assert len(record) == 0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryRecallAtFixedPrecision, {"task": "binary", "min_precision": 0.5}), + (MulticlassRecallAtFixedPrecision, {"task": "multiclass", "num_classes": 3, "min_precision": 0.5}), + (MultilabelRecallAtFixedPrecision, {"task": "multilabel", "num_labels": 3, "min_precision": 0.5}), + (None, {"task": "not_valid_task", "min_precision": 0.5}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=RecallAtFixedPrecision): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index 339079801cf..41f5f1d9253 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -19,8 +19,9 @@ from scipy.special import expit as sigmoid from scipy.special import softmax 
from sklearn.metrics import roc_curve as sk_roc_curve -from torchmetrics.classification.roc import BinaryROC, MulticlassROC, MultilabelROC +from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc +from torchmetrics.metric import Metric from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -373,3 +374,24 @@ def test_valid_input_thresholds(metric, thresholds): with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryROC, {"task": "binary"}), + (MulticlassROC, {"task": "multiclass", "num_classes": 3}), + (MultilabelROC, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=ROC): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 12bdd57edb1..f7070d91560 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -19,12 +19,18 @@ from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix from torch import Tensor, tensor -from torchmetrics.classification.specificity import BinarySpecificity, MulticlassSpecificity, MultilabelSpecificity +from torchmetrics.classification.specificity import ( + BinarySpecificity, + MulticlassSpecificity, + MultilabelSpecificity, + Specificity, +) from torchmetrics.functional.classification.specificity import ( binary_specificity, multiclass_specificity, multilabel_specificity, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -565,3 +571,24 @@ def test_corner_cases(): metric = MulticlassSpecificity(num_classes=3, average="macro", ignore_index=0) res = metric(preds, target) assert res == 1.0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinarySpecificity, {"task": "binary"}), + (MulticlassSpecificity, {"task": "multiclass", "num_classes": 3}), + (MultilabelSpecificity, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=Specificity): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index dae52dfd6d2..cdb13c52bb4 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -24,6 +24,7 @@ BinarySpecificityAtSensitivity, MulticlassSpecificityAtSensitivity, MultilabelSpecificityAtSensitivity, + 
SpecificityAtSensitivity, ) from torchmetrics.functional.classification.specificity_sensitivity import ( _convert_fpr_to_specificity, @@ -31,6 +32,7 @@ multiclass_specificity_at_sensitivity, multilabel_specificity_at_sensitivity, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -455,3 +457,24 @@ def test_valid_input_thresholds(metric, thresholds): with pytest.warns(None) as record: metric(min_sensitivity=0.5, thresholds=thresholds) assert len(record) == 0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinarySpecificityAtSensitivity, {"task": "binary", "min_sensitivity": 0.5}), + (MulticlassSpecificityAtSensitivity, {"task": "multiclass", "num_classes": 3, "min_sensitivity": 0.5}), + (MultilabelSpecificityAtSensitivity, {"task": "multilabel", "num_labels": 3, "min_sensitivity": 0.5}), + (None, {"task": "not_valid_task", "min_sensitivity": 0.5}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=SpecificityAtSensitivity): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index f8b22a392fb..f6b43fda6c7 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -18,12 +18,18 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.classification.stat_scores import ( + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, multiclass_stat_scores, multilabel_stat_scores, ) +from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases @@ -534,3 +540,24 @@ def test_multilabel_stat_scores_dtype_gpu(self, inputs, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryStatScores, {"task": "binary"}), + (MulticlassStatScores, {"task": "multiclass", "num_classes": 3}), + (MultilabelStatScores, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=StatScores): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric)
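
The wrapper tests added above all check the same contract: constructing a task wrapper such as JaccardIndex, Precision, ROC or StatScores with a `task` argument returns an instance of the matching task-specific class, that instance is also an instance of `Metric` (and the wrapper class itself passes `issubclass(..., Metric)`), and an unrecognized task string raises a `ValueError`. Below is a minimal usage sketch of the behavior these tests assert, using JaccardIndex as the example; the tensor values are illustrative only and are not taken from the test suite.

    import torch

    from torchmetrics import JaccardIndex, Metric
    from torchmetrics.classification import BinaryJaccardIndex

    # The task wrapper dispatches to the task-specific subclass at construction time.
    metric = JaccardIndex(task="binary")
    assert isinstance(metric, BinaryJaccardIndex)
    assert isinstance(metric, Metric)  # the wrapper yields a true Metric instance

    # The returned object behaves like any other Metric.
    preds = torch.tensor([0.2, 0.8, 0.6, 0.1])
    target = torch.tensor([0, 1, 1, 0])
    score = metric(preds, target)
    print(score)

    # An unknown task string is rejected, matching the tests' "Invalid" error pattern.
    try:
        JaccardIndex(task="not_valid_task")
    except ValueError as err:
        print(f"rejected: {err}")

The same pattern holds for every wrapper exercised above (Precision, Recall, PrecisionAtFixedRecall, RecallAtFixedPrecision, PrecisionRecallCurve, ROC, Specificity, SpecificityAtSensitivity, MatthewsCorrCoef and StatScores), with the task-specific keyword arguments (num_classes, num_labels, min_recall, min_precision, min_sensitivity) supplied alongside task, as in the parametrizations above.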