From 02f050a7ba912d7340649d10055a0944fd2a2da0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 22 Oct 2024 19:07:45 +0200 Subject: [PATCH] Newmetric: NRMSE (#2442) --- CHANGELOG.md | 2 +- docs/source/conf.py | 3 + docs/source/links.rst | 1 + .../normalized_root_mean_squared_error.rst | 21 ++ requirements/_devel.txt | 1 + requirements/regression_test.txt | 1 + src/torchmetrics/__init__.py | 36 +-- src/torchmetrics/functional/__init__.py | 24 +- .../functional/regression/__init__.py | 8 +- .../functional/regression/nrmse.py | 106 +++++++ src/torchmetrics/regression/__init__.py | 8 +- src/torchmetrics/regression/nrmse.py | 279 ++++++++++++++++++ tests/README.md | 6 +- tests/unittests/regression/test_mean_error.py | 178 +++++++++-- tests/unittests/utilities/test_plot.py | 2 + 15 files changed, 621 insertions(+), 55 deletions(-) create mode 100644 docs/source/regression/normalized_root_mean_squared_error.rst create mode 100644 requirements/regression_test.txt create mode 100644 src/torchmetrics/functional/regression/nrmse.py create mode 100644 src/torchmetrics/regression/nrmse.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8736f561615..99741831e6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442)) ### Changed diff --git a/docs/source/conf.py b/docs/source/conf.py index 5442f9641a9..81f842e7a12 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -447,6 +447,9 @@ def linkcode_resolve(domain, info) -> Optional[str]: # noqa: ANN001 "https://aclanthology.org/W17-4770", # A wavelet transform method to merge Landsat TM and SPOT panchromatic data "https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013", + # Improved normalization of time-lapse seismic data using normalized root mean square repeatability data ... + # ... to improve automatic production and seismic history matching in the Nelson field + "https://onlinelibrary.wiley.com/doi/abs/10.1111/1365-2478.12109", # todo: these links seems to be unstable, referring to .devcontainer "https://code.visualstudio.com", "https://code.visualstudio.com/.*", diff --git a/docs/source/links.rst b/docs/source/links.rst index 2e9b222f28f..b7a4f63565e 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -171,6 +171,7 @@ .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 +.. _Normalized Root Mean Squared Error: https://onlinelibrary.wiley.com/doi/abs/10.1111/1365-2478.12109 .. _Generalized Dice Score: https://arxiv.org/abs/1707.03237 .. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html diff --git a/docs/source/regression/normalized_root_mean_squared_error.rst b/docs/source/regression/normalized_root_mean_squared_error.rst new file mode 100644 index 00000000000..7bbc2f392d5 --- /dev/null +++ b/docs/source/regression/normalized_root_mean_squared_error.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Normalized Root Mean Squared Error (NRMSE) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Regression + +.. include:: ../links.rst + +########################################## +Normalized Root Mean Squared Error (NRMSE) +########################################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.NormalizedRootMeanSquaredError + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.normalized_root_mean_squared_error diff --git a/requirements/_devel.txt b/requirements/_devel.txt index 596cc138133..6a8ea2b8e7f 100644 --- a/requirements/_devel.txt +++ b/requirements/_devel.txt @@ -20,3 +20,4 @@ -r classification_test.txt -r nominal_test.txt -r segmentation_test.txt +-r regression_test.txt diff --git a/requirements/regression_test.txt b/requirements/regression_test.txt new file mode 100644 index 00000000000..859605fda3b --- /dev/null +++ b/requirements/regression_test.txt @@ -0,0 +1 @@ +permetrics==2.0.0 diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 2fa370cb1c9..a6105df3480 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -114,6 +114,7 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + NormalizedRootMeanSquaredError, PearsonCorrCoef, R2Score, RelativeSquaredError, @@ -158,25 +159,23 @@ ) __all__ = [ - "functional", - "Accuracy", "AUROC", + "Accuracy", "AveragePrecision", "BLEUScore", "BootStrapper", + "CHRFScore", "CalibrationError", "CatMetric", - "ClasswiseWrapper", "CharErrorRate", - "CHRFScore", - "ConcordanceCorrCoef", + "ClasswiseWrapper", "CohenKappa", + "ConcordanceCorrCoef", "ConfusionMatrix", "CosineSimilarity", "CramersV", "CriticalSuccessIndex", "Dice", - "TweedieDevianceScore", "ErrorRelativeGlobalDimensionlessSynthesis", "ExactMatch", "ExplainedVariance", @@ -187,8 +186,8 @@ "HammingDistance", "HingeLoss", "JaccardIndex", - "KendallRankCorrCoef", "KLDivergence", + "KendallRankCorrCoef", "LogCoshError", "MatchErrorRate", "MatthewsCorrCoef", @@ -201,14 +200,16 @@ "Metric", "MetricCollection", "MetricTracker", - "MinkowskiDistance", "MinMaxMetric", "MinMetric", + "MinkowskiDistance", "ModifiedPanopticQuality", + "MultiScaleStructuralSimilarityIndexMeasure", "MultioutputWrapper", "MultitaskWrapper", - "MultiScaleStructuralSimilarityIndexMeasure", + "NormalizedRootMeanSquaredError", "PanopticQuality", + "PeakSignalNoiseRatio", "PearsonCorrCoef", "PearsonsContingencyCoefficient", "PermutationInvariantTraining", @@ -216,8 +217,8 @@ "Precision", "PrecisionAtFixedRecall", "PrecisionRecallCurve", - "PeakSignalNoiseRatio", "R2Score", + "ROC", "Recall", "RecallAtFixedPrecision", "RelativeAverageSpectralError", @@ -228,37 +229,38 @@ "RetrievalMRR", "RetrievalNormalizedDCG", "RetrievalPrecision", - "RetrievalRecall", - "RetrievalRPrecision", "RetrievalPrecisionRecallCurve", + "RetrievalRPrecision", + "RetrievalRecall", "RetrievalRecallAtFixedPrecision", - "ROC", "RootMeanSquaredErrorUsingSlidingWindow", "RunningMean", "RunningSum", + "SQuAD", "SacreBLEUScore", - "SignalDistortionRatio", "ScaleInvariantSignalDistortionRatio", "ScaleInvariantSignalNoiseRatio", + "SensitivityAtSpecificity", + "SignalDistortionRatio", "SignalNoiseRatio", "SpearmanCorrCoef", "Specificity", "SpecificityAtSensitivity", - "SensitivityAtSpecificity", "SpectralAngleMapper", "SpectralDistortionIndex", - "SQuAD", - "StructuralSimilarityIndexMeasure", "StatScores", + "StructuralSimilarityIndexMeasure", "SumMetric", "SymmetricMeanAbsolutePercentageError", "TheilsU", "TotalVariation", "TranslationEditRate", "TschuprowsT", + "TweedieDevianceScore", "UniversalImageQualityIndex", "WeightedMeanAbsolutePercentageError", "WordErrorRate", "WordInfoLost", "WordInfoPreserved", + "functional", ] diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 30a7145aa71..7de7f261867 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -100,6 +100,7 @@ mean_squared_error, mean_squared_log_error, minkowski_distance, + normalized_root_mean_squared_error, pearson_corrcoef, r2_score, relative_squared_error, @@ -146,14 +147,13 @@ "calibration_error", "char_error_rate", "chrf_score", - "concordance_corrcoef", "cohen_kappa", + "concordance_corrcoef", "confusion_matrix", "cosine_similarity", "cramers_v", "cramers_v_matrix", "critical_success_index", - "tweedie_deviance_score", "dice", "error_relative_global_dimensionless_synthesis", "exact_match", @@ -177,12 +177,14 @@ "mean_squared_log_error", "minkowski_distance", "multiscale_structural_similarity_index_measure", + "normalized_root_mean_squared_error", "pairwise_cosine_similarity", "pairwise_euclidean_distance", "pairwise_linear_similarity", "pairwise_manhattan_distance", "pairwise_minkowski_distance", "panoptic_quality", + "peak_signal_noise_ratio", "pearson_corrcoef", "pearsons_contingency_coefficient", "pearsons_contingency_coefficient_matrix", @@ -190,10 +192,11 @@ "perplexity", "pit_permutate", "precision", + "precision_at_fixed_recall", "precision_recall_curve", - "peak_signal_noise_ratio", "r2_score", "recall", + "recall_at_fixed_precision", "relative_average_spectral_error", "relative_squared_error", "retrieval_average_precision", @@ -201,24 +204,27 @@ "retrieval_hit_rate", "retrieval_normalized_dcg", "retrieval_precision", + "retrieval_precision_recall_curve", "retrieval_r_precision", "retrieval_recall", "retrieval_reciprocal_rank", - "retrieval_precision_recall_curve", "roc", "root_mean_squared_error_using_sliding_window", "rouge_score", "sacre_bleu_score", - "signal_distortion_ratio", "scale_invariant_signal_distortion_ratio", "scale_invariant_signal_noise_ratio", + "sensitivity_at_specificity", + "signal_distortion_ratio", "signal_noise_ratio", "spearman_corrcoef", "specificity", + "specificity_at_sensitivity", + "spectral_angle_mapper", "spectral_distortion_index", "squad", - "structural_similarity_index_measure", "stat_scores", + "structural_similarity_index_measure", "symmetric_mean_absolute_percentage_error", "theils_u", "theils_u_matrix", @@ -226,14 +232,10 @@ "translation_edit_rate", "tschuprows_t", "tschuprows_t_matrix", + "tweedie_deviance_score", "universal_image_quality_index", - "spectral_angle_mapper", "weighted_mean_absolute_percentage_error", "word_error_rate", "word_information_lost", "word_information_preserved", - "precision_at_fixed_recall", - "recall_at_fixed_precision", - "sensitivity_at_specificity", - "specificity_at_sensitivity", ] diff --git a/src/torchmetrics/functional/regression/__init__.py b/src/torchmetrics/functional/regression/__init__.py index c2dab8c5f59..063fbc059e3 100644 --- a/src/torchmetrics/functional/regression/__init__.py +++ b/src/torchmetrics/functional/regression/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.functional.regression.mape import mean_absolute_percentage_error from torchmetrics.functional.regression.minkowski import minkowski_distance from torchmetrics.functional.regression.mse import mean_squared_error +from torchmetrics.functional.regression.nrmse import normalized_root_mean_squared_error from torchmetrics.functional.regression.pearson import pearson_corrcoef from torchmetrics.functional.regression.r2 import r2_score from torchmetrics.functional.regression.rse import relative_squared_error @@ -39,13 +40,14 @@ "kendall_rank_corrcoef", "kl_divergence", "log_cosh_error", - "mean_squared_log_error", "mean_absolute_error", - "mean_squared_error", - "pearson_corrcoef", "mean_absolute_percentage_error", "mean_absolute_percentage_error", + "mean_squared_error", + "mean_squared_log_error", "minkowski_distance", + "normalized_root_mean_squared_error", + "pearson_corrcoef", "r2_score", "relative_squared_error", "spearman_corrcoef", diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py new file mode 100644 index 00000000000..52cae36adb0 --- /dev/null +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -0,0 +1,106 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.mse import _mean_squared_error_update + + +def _normalized_root_mean_squared_error_update( + preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std", "l2"] = "mean" +) -> Tuple[Tensor, int, Tensor]: + """Updates and returns the sum of squared errors and the number of observations for NRMSE computation. + + Args: + preds: Predicted tensor + target: Ground truth tensor + num_outputs: Number of outputs in multioutput setting + normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" + + """ + sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs) + + target = target.view(-1) if num_outputs == 1 else target + if normalization == "mean": + denom = torch.mean(target, dim=0) + elif normalization == "range": + denom = torch.max(target, dim=0).values - torch.min(target, dim=0).values + elif normalization == "std": + denom = torch.std(target, correction=0, dim=0) + elif normalization == "l2": + denom = torch.norm(target, p=2, dim=0) + else: + raise ValueError( + f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2' but got {normalization}" + ) + return sum_squared_error, num_obs, denom + + +def _normalized_root_mean_squared_error_compute( + sum_squared_error: Tensor, num_obs: Union[int, Tensor], denom: Tensor +) -> Tensor: + """Calculates RMSE and normalizes it.""" + rmse = torch.sqrt(sum_squared_error / num_obs) + return rmse / denom + + +def normalized_root_mean_squared_error( + preds: Tensor, + target: Tensor, + normalization: Literal["mean", "range", "std", "l2"] = "mean", + num_outputs: int = 1, +) -> Tensor: + """Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index. + + Args: + preds: estimated labels + target: ground truth labels + normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" which corresponds + to normalizing the RMSE by the mean of the target, the range of the target, the standard deviation of the + target or the L2 norm of the target. + num_outputs: Number of outputs in multioutput setting + + Return: + Tensor with the NRMSE score + + Example: + >>> import torch + >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error + >>> preds = torch.tensor([0., 1, 2, 3]) + >>> target = torch.tensor([0., 1, 2, 2]) + >>> normalized_root_mean_squared_error(preds, target, normalization="mean") + tensor(0.4000) + >>> normalized_root_mean_squared_error(preds, target, normalization="range") + tensor(0.2500) + >>> normalized_root_mean_squared_error(preds, target, normalization="std") + tensor(0.6030) + >>> normalized_root_mean_squared_error(preds, target, normalization="l2") + tensor(0.1667) + + Example (multioutput): + >>> import torch + >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error + >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]]) + >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]]) + >>> normalized_root_mean_squared_error(preds, target, normalization="mean", num_outputs=2) + tensor([0.2981, 0.2222]) + + """ + sum_squared_error, num_obs, denom = _normalized_root_mean_squared_error_update( + preds, target, num_outputs=num_outputs, normalization=normalization + ) + return _normalized_root_mean_squared_error_compute(sum_squared_error, num_obs, denom) diff --git a/src/torchmetrics/regression/__init__.py b/src/torchmetrics/regression/__init__.py index 03ba8023a10..6a41c01bcdb 100644 --- a/src/torchmetrics/regression/__init__.py +++ b/src/torchmetrics/regression/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.regression.mape import MeanAbsolutePercentageError from torchmetrics.regression.minkowski import MinkowskiDistance from torchmetrics.regression.mse import MeanSquaredError +from torchmetrics.regression.nrmse import NormalizedRootMeanSquaredError from torchmetrics.regression.pearson import PearsonCorrCoef from torchmetrics.regression.r2 import R2Score from torchmetrics.regression.rse import RelativeSquaredError @@ -36,14 +37,15 @@ "CosineSimilarity", "CriticalSuccessIndex", "ExplainedVariance", - "KendallRankCorrCoef", "KLDivergence", + "KendallRankCorrCoef", "LogCoshError", - "MeanSquaredLogError", "MeanAbsoluteError", "MeanAbsolutePercentageError", - "MinkowskiDistance", "MeanSquaredError", + "MeanSquaredLogError", + "MinkowskiDistance", + "NormalizedRootMeanSquaredError", "PearsonCorrCoef", "R2Score", "RelativeSquaredError", diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py new file mode 100644 index 00000000000..62562803542 --- /dev/null +++ b/src/torchmetrics/regression/nrmse.py @@ -0,0 +1,279 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.nrmse import ( + _mean_squared_error_update, + _normalized_root_mean_squared_error_compute, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["NormalizedRootMeanSquaredError.plot"] + + +def _final_aggregation( + min_val: Tensor, + max_val: Tensor, + mean_val: Tensor, + var_val: Tensor, + target_squared: Tensor, + total: Tensor, + normalization: Literal["mean", "range", "std", "l2"] = "mean", +) -> Tensor: + """In the case of multiple devices we need to aggregate the statistics from the different devices.""" + if len(min_val) == 1: + if normalization == "mean": + return mean_val[0] + if normalization == "range": + return max_val[0] - min_val[0] + if normalization == "std": + return var_val[0] + if normalization == "l2": + return target_squared[0] + + min_val_1, max_val_1, mean_val_1, var_val_1, target_squared_1, total_1 = ( + min_val[0], + max_val[0], + mean_val[0], + var_val[0], + target_squared[0], + total[0], + ) + for i in range(1, len(min_val)): + min_val_2, max_val_2, mean_val_2, var_val_2, target_squared_2, total_2 = ( + min_val[i], + max_val[i], + mean_val[i], + var_val[i], + target_squared[i], + total[i], + ) + # update total and mean + total = total_1 + total_2 + mean = (total_1 * mean_val_1 + total_2 * mean_val_2) / total + + # update variance + _temp = (total_1 + 1) * mean - total_1 * mean_val_1 + var_val_1 += (_temp - mean_val_1) * (_temp - mean) - (_temp - mean) ** 2 + _temp = (total_2 + 1) * mean - total_2 * mean_val_2 + var_val_2 += (_temp - mean_val_2) * (_temp - mean) - (_temp - mean) ** 2 + var = var_val_1 + var_val_2 + + # update min and max and target squared + min_val = torch.min(min_val_1, min_val_2) + max_val = torch.max(max_val_1, max_val_2) + target_squared = target_squared_1 + target_squared_2 + + if normalization == "mean": + return mean + if normalization == "range": + return max_val - min_val + if normalization == "std": + return (var / total).sqrt() + return target_squared.sqrt() + + +class NormalizedRootMeanSquaredError(Metric): + r"""Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index. + + The metric is defined as: + + .. math:: + \text{NRMSE} = \frac{\text{RMSE}}{\text{denom}} + + where RMSE is the root mean squared error and `denom` is the normalization factor. The normalization factor can be + either be the mean, range, standard deviation or L2 norm of the target, which can be set using the `normalization` + argument. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``nrmse`` (:class:`~torch.Tensor`): A tensor with the mean squared error + + Args: + normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" which corresponds + to normalizing the RMSE by the mean of the target, the range of the target, the standard deviation of the + target or the L2 norm of the target. + num_outputs: Number of outputs in multioutput setting + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example:: + Single output normalized root mean squared error computation: + + >>> import torch + >>> from torchmetrics import NormalizedRootMeanSquaredError + >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) + >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) + >>> nrmse = NormalizedRootMeanSquaredError(normalization="mean") + >>> nrmse(preds, target) + tensor(0.1919) + >>> nrmse = NormalizedRootMeanSquaredError(normalization="range") + >>> nrmse(preds, target) + tensor(0.1701) + + Example:: + Multioutput normalized root mean squared error computation: + + >>> import torch + >>> from torchmetrics import NormalizedRootMeanSquaredError + >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]]) + >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]]) + >>> nrmse = NormalizedRootMeanSquaredError(num_outputs=2) + >>> nrmse(preds, target) + tensor([0.2981, 0.2222]) + + """ + + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = True + plot_lower_bound: float = 0.0 + + sum_squared_error: Tensor + total: Tensor + min_val: Tensor + max_val: Tensor + target_squared: Tensor + mean_val: Tensor + var_val: Tensor + + def __init__( + self, + normalization: Literal["mean", "range", "std", "l2"] = "mean", + num_outputs: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + if normalization not in ("mean", "range", "std", "l2"): + raise ValueError( + f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2', but got {normalization}" + ) + self.normalization = normalization + + if not (isinstance(num_outputs, int) and num_outputs > 0): + raise ValueError(f"Expected num_outputs to be a positive integer but got {num_outputs}") + self.num_outputs = num_outputs + + self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=torch.zeros(num_outputs), dist_reduce_fx=None) + self.add_state("min_val", default=float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) + self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) + self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("var_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("target_squared", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. + + See `mean_squared_error_update` for details. + + """ + sum_squared_error, num_obs = _mean_squared_error_update(preds, target, self.num_outputs) + self.sum_squared_error += sum_squared_error + target = target.view(-1) if self.num_outputs == 1 else target + + # Update min and max and target squared + self.min_val = torch.minimum(target.min(dim=0).values, self.min_val) + self.max_val = torch.maximum(target.max(dim=0).values, self.max_val) + self.target_squared += (target**2).sum(dim=0) + + # Update mean and variance + new_mean = (self.total * self.mean_val + target.sum(dim=0)) / (self.total + num_obs) + self.total += num_obs + new_var = ((target - new_mean) * (target - self.mean_val)).sum(dim=0) + self.mean_val = new_mean + self.var_val += new_var + + def compute(self) -> Tensor: + """Computes NRMSE over state. + + See `mean_squared_error_compute` for details. + + """ + if (self.num_outputs == 1 and self.mean_val.numel() > 1) or (self.num_outputs > 1 and self.mean_val.ndim > 1): + denom = _final_aggregation( + min_val=self.min_val, + max_val=self.max_val, + mean_val=self.mean_val, + var_val=self.var_val, + target_squared=self.target_squared, + total=self.total, + normalization=self.normalization, + ) + total = self.total.squeeze().sum(dim=0) + else: + if self.normalization == "mean": + denom = self.mean_val + elif self.normalization == "range": + denom = self.max_val - self.min_val + elif self.normalization == "std": + denom = torch.sqrt(self.var_val / self.total) + else: + denom = torch.sqrt(self.target_squared) + total = self.total + return _normalized_root_mean_squared_error_compute(self.sum_squared_error, total, denom) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import NormalizedRootMeanSquaredError + >>> metric = NormalizedRootMeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import NormalizedRootMeanSquaredError + >>> metric = NormalizedRootMeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/tests/README.md b/tests/README.md index 7f5cbd4e98a..6fce25567ef 100644 --- a/tests/README.md +++ b/tests/README.md @@ -7,16 +7,16 @@ the following command in the root directory of the project: pip install . -r requirements/_devel.txt ``` -Then for windows users, to execute the tests (unit tests and integration tests) run the following command (will only run non-DDP tests): +Then for Windows users, to execute the tests (unit tests and integration tests) run the following command (will only run non-DDP tests): ```bash pytest tests/ ``` -For linux/Mac users you will need to provide the `-m` argument to indicate if `ddp` tests should also be executed: +For Linux/Mac users you will need to provide the `-m` argument to indicate if `ddp` tests should also be executed: ```bash -pytest -m DDP tests/ # to run only DDP tests +USE_PYTEST_POOL="1" pytest -m DDP tests/ # to run only DDP tests pytest -m "not DDP" tests/ # to run all tests except DDP tests ``` diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index f37e80e4d16..38c86817184 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -18,6 +18,7 @@ import numpy as np import pytest import torch +from permetrics.regression import RegressionMetric from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error from sklearn.metrics import mean_squared_error as sk_mean_squared_error @@ -29,6 +30,7 @@ mean_absolute_percentage_error, mean_squared_error, mean_squared_log_error, + normalized_root_mean_squared_error, weighted_mean_absolute_percentage_error, ) from torchmetrics.functional.regression.symmetric_mape import symmetric_mean_absolute_percentage_error @@ -39,6 +41,7 @@ MeanSquaredLogError, WeightedMeanAbsolutePercentageError, ) +from torchmetrics.regression.nrmse import NormalizedRootMeanSquaredError from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError from unittests import BATCH_SIZE, NUM_BATCHES, _Input @@ -114,66 +117,179 @@ def _reference_symmetric_mape( return np.average(output_errors, weights=multioutput) +def _reference_normalized_root_mean_squared_error( + y_true: np.ndarray, y_pred: np.ndarray, normalization: str = "mean", num_outputs: int = 1 +): + """Reference implementation of Normalized Root Mean Squared Error (NRMSE) metric.""" + if num_outputs == 1: + y_true = y_true.flatten() + y_pred = y_pred.flatten() + if normalization != "l2": + evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true) + arg_mapping = {"mean": 1, "range": 2, "std": 4} + return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization]) + # for l2 normalization we do not have a reference implementation + return np.sqrt(np.mean(np.square(y_true - y_pred), axis=0)) / np.linalg.norm(y_true, axis=0) + + def _reference_weighted_mean_abs_percentage_error(target, preds): + """Reference implementation of Weighted Mean Absolute Percentage Error (WMAPE) metric.""" return np.sum(np.abs(target - preds)) / np.sum(np.abs(target)) def _single_target_ref_wrapper(preds, target, sk_fn, metric_args): + """Reference implementation of single-target metrics.""" sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - res = sk_fn(sk_target, sk_preds) - - return math.sqrt(res) if (metric_args and "squared" in metric_args and not metric_args["squared"]) else res + if metric_args and "normalization" in metric_args: + res = sk_fn(sk_target, sk_preds, normalization=metric_args["normalization"]) + else: + res = sk_fn(sk_target, sk_preds) + if metric_args and "squared" in metric_args and not metric_args["squared"]: + res = math.sqrt(res) + return res def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): + """Reference implementation of multi-target metrics.""" sk_preds = preds.view(-1, NUM_TARGETS).numpy() sk_target = target.view(-1, NUM_TARGETS).numpy() sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {} - res = sk_fn(sk_target, sk_preds, **sk_kwargs) - return math.sqrt(res) if (metric_args and "squared" in metric_args and not metric_args["squared"]) else res + if metric_args and "normalization" in metric_args: + res = sk_fn(sk_target, sk_preds, **metric_args) + else: + res = sk_fn(sk_target, sk_preds, **sk_kwargs) + if metric_args and "squared" in metric_args and not metric_args["squared"]: + res = math.sqrt(res) + return res @pytest.mark.parametrize( - "preds, target, ref_metric", + ("preds", "target", "ref_metric"), [ (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper), (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper), ], ) @pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn, metric_args", + ("metric_class", "metric_functional", "sk_fn", "metric_args"), [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": NUM_TARGETS}), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {"num_outputs": NUM_TARGETS}), - (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}), - ( + pytest.param( + MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}, id="mse_singleoutput" + ), + pytest.param( + MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}, id="rmse_singleoutput" + ), + pytest.param( + MeanSquaredError, + mean_squared_error, + sk_mean_squared_error, + {"squared": True, "num_outputs": NUM_TARGETS}, + id="mse_multioutput", + ), + pytest.param(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}, id="mae_singleoutput"), + pytest.param( + MeanAbsoluteError, + mean_absolute_error, + sk_mean_absolute_error, + {"num_outputs": NUM_TARGETS}, + id="mae_multioutput", + ), + pytest.param( + MeanAbsolutePercentageError, + mean_absolute_percentage_error, + sk_mean_abs_percentage_error, + {}, + id="mape_singleoutput", + ), + pytest.param( SymmetricMeanAbsolutePercentageError, symmetric_mean_absolute_percentage_error, _reference_symmetric_mape, {}, + id="symmetric_mean_absolute_percentage_error", ), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}), - ( + pytest.param( + MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}, id="mean_squared_log_error" + ), + pytest.param( WeightedMeanAbsolutePercentageError, weighted_mean_absolute_percentage_error, _reference_weighted_mean_abs_percentage_error, {}, + id="weighted_mean_absolute_percentage_error", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "mean", "num_outputs": 1}, + id="nrmse_singleoutput_mean", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "range", "num_outputs": 1}, + id="nrmse_singleoutput_range", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "std", "num_outputs": 1}, + id="nrmse_singleoutput_std", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "l2", "num_outputs": 1}, + id="nrmse_multioutput_l2", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "mean", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_mean", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "range", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_range", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "std", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_std", + ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "l2", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_l2", ), ], ) class TestMeanError(MetricTester): """Test class for `MeanError` metric.""" + atol = 1e-5 + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_mean_error_class( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args, ddp ): """Test class implementation of metric.""" + if metric_args and "num_outputs" in metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_class_metric_test( ddp=ddp, preds=preds, @@ -187,6 +303,8 @@ def test_mean_error_functional( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args ): """Test functional implementation of metric.""" + if metric_args and "num_outputs" in metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_functional_metric_test( preds=preds, target=target, @@ -199,6 +317,8 @@ def test_mean_error_differentiability( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args ): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + if metric_args and "num_outputs" in metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_differentiability_test( preds=preds, target=target, @@ -225,6 +345,10 @@ def test_mean_error_half_cpu(self, preds, target, ref_metric, metric_class, metr # WeightedMeanAbsolutePercentageError half + cpu does not work due to missing support in torch.clamp pytest.xfail("WeightedMeanAbsolutePercentageError metric does not support cpu + half precision") + if metric_class == NormalizedRootMeanSquaredError: + # NormalizedRootMeanSquaredError half + cpu does not work due to missing support in torch.sqrt + pytest.xfail("NormalizedRootMeanSquaredError metric does not support cpu + half precision") + self.run_precision_test_cpu(preds, target, metric_class, metric_functional) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @@ -234,10 +358,30 @@ def test_mean_error_half_gpu(self, preds, target, ref_metric, metric_class, metr @pytest.mark.parametrize( - "metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError, MeanAbsolutePercentageError] + "metric_class", + [ + MeanSquaredError, + MeanAbsoluteError, + MeanSquaredLogError, + MeanAbsolutePercentageError, + NormalizedRootMeanSquaredError, + ], ) def test_error_on_different_shape(metric_class): """Test that error is raised on different shapes of input.""" metric = metric_class() with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) + + +@pytest.mark.parametrize( + ("metric_class", "arguments", "error_msg"), + [ + (MeanSquaredError, {"squared": "something"}, "Expected argument `squared` to be a boolean.*"), + (NormalizedRootMeanSquaredError, {"normalization": "something"}, "Argument `normalization` should be either.*"), + ], +) +def test_error_on_wrong_extra_args(metric_class, arguments, error_msg): + """Test that error is raised on wrong extra arguments.""" + with pytest.raises(ValueError, match=error_msg): + metric_class(**arguments) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 5b85a01af5a..efb7077682e 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -130,6 +130,7 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + NormalizedRootMeanSquaredError, PearsonCorrCoef, R2Score, RelativeSquaredError, @@ -469,6 +470,7 @@ pytest.param(MeanAbsoluteError, _rand_input, _rand_input, id="mean absolute error"), pytest.param(MeanAbsolutePercentageError, _rand_input, _rand_input, id="mean absolute percentage error"), pytest.param(partial(MinkowskiDistance, p=3), _rand_input, _rand_input, id="minkowski distance"), + pytest.param(NormalizedRootMeanSquaredError, _rand_input, _rand_input, id="normalized root mean squared error"), pytest.param(PearsonCorrCoef, _rand_input, _rand_input, id="pearson corr coef"), pytest.param(R2Score, _rand_input, _rand_input, id="r2 score"), pytest.param(RelativeSquaredError, _rand_input, _rand_input, id="relative squared error"),