From a3ed72ffa9efe4199848b2930c7b4086a886f4f3 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 27 Feb 2024 20:01:27 +0100 Subject: [PATCH 01/10] ci: fix PR's file changed limit (#2412) --- .azure/gpu-integrations.yml | 2 +- .azure/gpu-unittests.yml | 1 + .github/actions/pull-caches/action.yml | 2 +- .github/assistant.py | 28 +++++--------------------- .github/workflows/_focus-diff.yml | 2 +- requirements/_tests.txt | 2 +- 6 files changed, 10 insertions(+), 27 deletions(-) diff --git a/.azure/gpu-integrations.yml b/.azure/gpu-integrations.yml index deff65423dc..024cc229b66 100644 --- a/.azure/gpu-integrations.yml +++ b/.azure/gpu-integrations.yml @@ -50,7 +50,7 @@ jobs: echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$CUDA_version_mm" echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${CUDA_version_mm}/torch_stable.html" # packages for running assistant - pip install -q packaging fire requests wget + pip install -q fire wget packaging displayName: "set Env. vars" - bash: | diff --git a/.azure/gpu-unittests.yml b/.azure/gpu-unittests.yml index 3a59a1385ee..2013ab871b7 100644 --- a/.azure/gpu-unittests.yml +++ b/.azure/gpu-unittests.yml @@ -73,6 +73,7 @@ jobs: displayName: "set Env. vars for PRs" - bash: | + pip install -q fire pyGithub printf "PR: $PR_NUMBER \n" focus=$(python .github/assistant.py changed-domains $PR_NUMBER) printf "focus: $focus \n" diff --git a/.github/actions/pull-caches/action.yml b/.github/actions/pull-caches/action.yml index b03ac913a9d..a5cf7cafe2d 100644 --- a/.github/actions/pull-caches/action.yml +++ b/.github/actions/pull-caches/action.yml @@ -23,7 +23,7 @@ runs: using: "composite" steps: - name: install assistant's deps - run: pip install -q fire requests packaging wget + run: pip install -q packaging fire wget shell: bash - name: Set PyTorch version diff --git a/.github/assistant.py b/.github/assistant.py index f9eb521bec6..95652888022 100644 --- a/.github/assistant.py +++ b/.github/assistant.py @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import glob -import json import logging import os import re import sys -import traceback from typing import List, Optional, Tuple, Union import fire -import requests from packaging.version import parse from pkg_resources import parse_requirements @@ -38,19 +35,6 @@ REQUIREMENTS_FILES = (*glob.glob(_path("requirements", "*.txt")), _path("requirements.txt")) -def request_url(url: str, auth_token: Optional[str] = None) -> Optional[dict]: - """General request with checking if request limit was reached.""" - auth_header = {"Authorization": f"token {auth_token}"} if auth_token else {} - try: - req = requests.get(url, headers=auth_header, timeout=_REQUEST_TIMEOUT) - except requests.exceptions.Timeout: - traceback.print_exc() - return None - if req.status_code == 403: - return None - return json.loads(req.content.decode(req.encoding)) - - class AssistantCLI: """CLI assistant for local CI.""" @@ -114,15 +98,13 @@ def changed_domains( general_sub_pkgs: Tuple[str] = _PKG_WIDE_SUBPACKAGES, ) -> Union[str, List[str]]: """Determine what domains were changed in particular PR.""" + import github + if not pr: return "unittests" - url = f"https://api.github.com/repos/Lightning-AI/torchmetrics/pulls/{pr}/files" - logging.debug(url) - data = request_url(url, auth_token) - if not data: - logging.debug("WARNING: No data was received -> test everything.") - return "unittests" - files = [d["filename"] for d in data] + gh = github.Github() + pr = gh.get_repo("Lightning-AI/torchmetrics").get_pull(pr) + files = [f.filename for f in pr.get_files()] # filter out all integrations as they run in separate suit files = [fn for fn in files if not fn.startswith("tests/integrations")] diff --git a/.github/workflows/_focus-diff.yml b/.github/workflows/_focus-diff.yml index 246eff24c1a..fbc8cab7d33 100644 --- a/.github/workflows/_focus-diff.yml +++ b/.github/workflows/_focus-diff.yml @@ -27,7 +27,7 @@ jobs: PR_NUMBER: "${{ github.event.pull_request.number }}" run: | echo $PR_NUMBER - pip install fire requests + pip install -q packaging fire pyGithub # python .github/assistant.py changed-domains $PR_NUMBER echo "focus=$(python .github/assistant.py changed-domains $PR_NUMBER)" >> $GITHUB_OUTPUT diff --git a/requirements/_tests.txt b/requirements/_tests.txt index cc1f1c1c60a..a30c04a3135 100644 --- a/requirements/_tests.txt +++ b/requirements/_tests.txt @@ -11,7 +11,7 @@ pytest-xdist ==3.5.0 phmdoctest ==1.4.0 psutil <5.10.0 -requests <=2.31.0 +pyGithub ==2.2.0 fire <=0.5.0 cloudpickle >1.3, <=3.0.0 From efb3a25ee7d729a57ca55de89b5e515452c099b8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 27 Feb 2024 20:14:38 +0100 Subject: [PATCH 02/10] ci: hotfix PRs' diff for GH actions (#2413) --- .github/workflows/_focus-diff.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/_focus-diff.yml b/.github/workflows/_focus-diff.yml index fbc8cab7d33..faa4812f375 100644 --- a/.github/workflows/_focus-diff.yml +++ b/.github/workflows/_focus-diff.yml @@ -10,7 +10,7 @@ on: jobs: eval-diff: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest timeout-minutes: 5 # Map the job outputs to step outputs outputs: @@ -26,8 +26,9 @@ jobs: env: PR_NUMBER: "${{ github.event.pull_request.number }}" run: | + set -e echo $PR_NUMBER - pip install -q packaging fire pyGithub + pip install -q -U packaging fire pyGithub pyopenssl # python .github/assistant.py changed-domains $PR_NUMBER echo "focus=$(python .github/assistant.py changed-domains $PR_NUMBER)" >> $GITHUB_OUTPUT From 
c53ea94e8fa001a9e389432f5c76aa8c33b27222 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 27 Feb 2024 23:05:26 +0100 Subject: [PATCH 03/10] ci: cache reference metrics & clean audio tests (#2335) * cache reference metrics * audio * classif * regress * image * others * cleaning --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .gitignore | 3 +- tests/unittests/audio/__init__.py | 23 +++++++ tests/unittests/audio/test_pesq.py | 23 ++----- tests/unittests/audio/test_pit.py | 68 +++++++++---------- tests/unittests/audio/test_sa_sdr.py | 26 ++++--- tests/unittests/audio/test_sdr.py | 44 ++++++------ tests/unittests/audio/test_si_sdr.py | 22 ++---- tests/unittests/audio/test_si_snr.py | 21 +++--- tests/unittests/audio/test_snr.py | 20 ++---- tests/unittests/audio/test_srmr.py | 22 +++--- tests/unittests/audio/test_stoi.py | 26 ++----- tests/unittests/bases/test_aggregation.py | 46 ++++++------- .../classification/{inputs.py => _inputs.py} | 0 .../unittests/classification/test_accuracy.py | 38 +++++------ tests/unittests/classification/test_auroc.py | 26 +++---- .../classification/test_average_precision.py | 32 +++++---- .../classification/test_calibration_error.py | 14 ++-- .../classification/test_cohen_kappa.py | 18 +++-- .../classification/test_confusion_matrix.py | 24 ++++--- tests/unittests/classification/test_dice.py | 30 ++++---- .../classification/test_exact_match.py | 14 ++-- tests/unittests/classification/test_f_beta.py | 34 ++++++---- .../classification/test_group_fairness.py | 10 +-- .../classification/test_hamming_distance.py | 54 +++++++-------- tests/unittests/classification/test_hinge.py | 14 ++-- .../unittests/classification/test_jaccard.py | 30 +++++--- .../classification/test_matthews_corrcoef.py | 20 +++--- .../test_precision_fixed_recall.py | 30 +++++--- .../classification/test_precision_recall.py | 28 ++++---- .../test_precision_recall_curve.py | 22 +++--- .../unittests/classification/test_ranking.py | 8 +-- .../test_recall_fixed_precision.py | 36 +++++++--- tests/unittests/classification/test_roc.py | 22 +++--- .../test_sensitivity_specificity.py | 28 +++++--- .../classification/test_specificity.py | 36 +++++----- .../test_specificity_sensitivity.py | 28 +++++--- .../classification/test_stat_scores.py | 28 ++++---- .../clustering/{inputs.py => _inputs.py} | 0 .../test_adjusted_mutual_info_score.py | 2 +- .../clustering/test_adjusted_rand_score.py | 2 +- .../test_calinski_harabasz_score.py | 2 +- .../clustering/test_davies_bouldin_score.py | 2 +- tests/unittests/clustering/test_dunn_index.py | 8 +-- .../clustering/test_fowlkes_mallows_index.py | 2 +- ...test_homogeneity_completeness_v_measure.py | 8 +-- .../clustering/test_mutual_info_score.py | 2 +- .../test_normalized_mutual_info_score.py | 2 +- tests/unittests/clustering/test_rand_score.py | 2 +- .../test_modified_panoptic_quality.py | 16 ++--- .../detection/test_panoptic_quality.py | 16 ++--- tests/unittests/image/test_csi.py | 6 +- tests/unittests/image/test_d_lambda.py | 24 ++----- tests/unittests/image/test_d_s.py | 4 +- tests/unittests/image/test_ergas.py | 10 +-- tests/unittests/image/test_lpips.py | 10 +-- tests/unittests/image/test_mifid.py | 7 +- tests/unittests/image/test_ms_ssim.py | 11 ++- tests/unittests/image/test_psnr.py | 18 ++--- tests/unittests/image/test_psnrb.py | 13 +++- tests/unittests/image/test_qnr.py | 60 ++++++---------- tests/unittests/image/test_rase.py | 10 +-- 
tests/unittests/image/test_rmse_sw.py | 10 +-- tests/unittests/image/test_sam.py | 14 ++-- tests/unittests/image/test_ssim.py | 26 +++---- tests/unittests/image/test_tv.py | 14 ++-- tests/unittests/image/test_uqi.py | 10 +-- tests/unittests/image/test_vif.py | 10 ++- tests/unittests/multimodal/test_clip_iqa.py | 10 +-- tests/unittests/multimodal/test_clip_score.py | 6 +- tests/unittests/nominal/test_cramers.py | 8 +-- tests/unittests/nominal/test_fleiss_kappa.py | 6 +- tests/unittests/nominal/test_pearson.py | 12 ++-- tests/unittests/nominal/test_theils_u.py | 14 ++-- tests/unittests/nominal/test_tschuprows.py | 16 +++-- .../unittests/regression/test_concordance.py | 6 +- .../regression/test_cosine_similarity.py | 12 ++-- .../regression/test_explained_variance.py | 10 +-- tests/unittests/regression/test_kendall.py | 6 +- .../regression/test_log_cosh_error.py | 16 ++--- tests/unittests/regression/test_mean_error.py | 28 ++++---- .../regression/test_minkowski_distance.py | 14 ++-- tests/unittests/regression/test_pearson.py | 6 +- tests/unittests/regression/test_r2.py | 18 ++--- tests/unittests/regression/test_rse.py | 24 +++---- tests/unittests/regression/test_spearman.py | 6 +- .../regression/test_tweedie_deviance.py | 6 +- .../retrieval/{inputs.py => _inputs.py} | 0 tests/unittests/retrieval/helpers.py | 24 +++---- .../unittests/text/{inputs.py => _inputs.py} | 0 tests/unittests/text/test_bertscore.py | 4 +- tests/unittests/text/test_bleu.py | 2 +- tests/unittests/text/test_cer.py | 2 +- tests/unittests/text/test_chrf.py | 2 +- tests/unittests/text/test_edit.py | 2 +- tests/unittests/text/test_eed.py | 2 +- tests/unittests/text/test_infolm.py | 2 +- tests/unittests/text/test_mer.py | 2 +- tests/unittests/text/test_perplexity.py | 2 +- tests/unittests/text/test_rouge.py | 2 +- tests/unittests/text/test_sacre_bleu.py | 14 ++-- tests/unittests/text/test_squad.py | 2 +- tests/unittests/text/test_ter.py | 2 +- tests/unittests/text/test_wer.py | 8 +-- tests/unittests/text/test_wil.py | 2 +- tests/unittests/text/test_wip.py | 2 +- 105 files changed, 809 insertions(+), 780 deletions(-) rename tests/unittests/classification/{inputs.py => _inputs.py} (100%) rename tests/unittests/clustering/{inputs.py => _inputs.py} (100%) rename tests/unittests/retrieval/{inputs.py => _inputs.py} (100%) rename tests/unittests/text/{inputs.py => _inputs.py} (100%) diff --git a/.gitignore b/.gitignore index 7f4c97f66db..6f45b493e3c 100644 --- a/.gitignore +++ b/.gitignore @@ -40,9 +40,8 @@ pip-delete-this-directory.txt # Unit test / coverage reports tests/_data/ data.zip +tests/_reference-cache/ htmlcov/ -.tox/ -.nox/ .coverage .coverage.* .cache diff --git a/tests/unittests/audio/__init__.py b/tests/unittests/audio/__init__.py index 791a7a86359..dd8a3c0014f 100644 --- a/tests/unittests/audio/__init__.py +++ b/tests/unittests/audio/__init__.py @@ -1,7 +1,30 @@ import os +from typing import Callable, Optional + +from torch import Tensor from unittests import _PATH_ALL_TESTS _SAMPLE_AUDIO_SPEECH = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech.wav") _SAMPLE_AUDIO_SPEECH_BAB_DB = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech_bab_0dB.wav") _SAMPLE_NUMPY_ISSUE_895 = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "issue_895.npz") + + +def _average_metric_wrapper( + preds: Tensor, target: Tensor, metric_func: Callable, res_index: Optional[int] = None +) -> Tensor: + """Average the metric values. 
+ + Args: + preds: predictions, shape[batch, spk, time] + target: targets, shape[batch, spk, time] + metric_func: a function which return best_metric and best_perm + res_index: if not None, return best_metric[res_index] + + Returns: + the average of best_metric + + """ + if res_index is not None: + return metric_func(preds, target)[res_index].mean() + return metric_func(preds, target).mean() diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index 8f30396c5d2..c10cfef6568 100644 --- a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -22,7 +22,7 @@ from torchmetrics.functional.audio import perceptual_evaluation_speech_quality from unittests import _Input -from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB +from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -41,7 +41,7 @@ ) -def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str): +def _reference_pesq_batch(preds: Tensor, target: Tensor, fs: int, mode: str): """Comparison function.""" # shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time] @@ -54,23 +54,12 @@ def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str): return torch.tensor(mss) -def _average_metric(preds, target, metric_func): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target).mean() - - -pesq_original_batch_8k_nb = partial(_pesq_original_batch, fs=8000, mode="nb") -pesq_original_batch_16k_nb = partial(_pesq_original_batch, fs=16000, mode="nb") -pesq_original_batch_16k_wb = partial(_pesq_original_batch, fs=16000, mode="wb") - - @pytest.mark.parametrize( "preds, target, ref_metric, fs, mode", [ - (inputs_8k.preds, inputs_8k.target, pesq_original_batch_8k_nb, 8000, "nb"), - (inputs_16k.preds, inputs_16k.target, pesq_original_batch_16k_nb, 16000, "nb"), - (inputs_16k.preds, inputs_16k.target, pesq_original_batch_16k_wb, 16000, "wb"), + (inputs_8k.preds, inputs_8k.target, partial(_reference_pesq_batch, fs=8000, mode="nb"), 8000, "nb"), + (inputs_16k.preds, inputs_16k.target, partial(_reference_pesq_batch, fs=16000, mode="nb"), 16000, "nb"), + (inputs_16k.preds, inputs_16k.target, partial(_reference_pesq_batch, fs=16000, mode="wb"), 16000, "wb"), ], ) class TestPESQ(MetricTester): @@ -89,7 +78,7 @@ def test_pesq(self, preds, target, ref_metric, fs, mode, num_processes, ddp): preds, target, PerceptualEvaluationSpeechQuality, - reference_metric=partial(_average_metric, metric_func=ref_metric), + reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric), metric_args={"fs": fs, "mode": mode, "n_processes": num_processes}, ) diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index ba13be57a10..107a775e728 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -31,27 +31,28 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES, _Input +from unittests.audio import _average_metric_wrapper from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -TIME = 10 +TIME_FRAME = 10 # three speaker examples to test _find_best_perm_by_linear_sum_assignment inputs1 = 
_Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME_FRAME), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME_FRAME), ) # two speaker examples to test _find_best_perm_by_exhuastive_method inputs2 = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME_FRAME), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME_FRAME), ) -def naive_implementation_pit_scipy( +def _reference_scipy_pit( preds: Tensor, target: Tensor, metric_func: Callable, @@ -66,10 +67,8 @@ def naive_implementation_pit_scipy( eval_func: min or max Returns: - best_metric: - shape [batch] - best_perm: - shape [batch, spk] + best_metric: shape [batch] + best_perm: shape [batch, spk] """ batch_size, spk_num = target.shape[0:2] @@ -88,62 +87,59 @@ def naive_implementation_pit_scipy( return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms)) -def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor: - """Average the metric values. +def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + return _reference_scipy_pit( + preds=preds, + target=target, + metric_func=signal_noise_ratio, + eval_func="max", + ) - Args: - preds: predictions, shape[batch, spk, time] - target: targets, shape[batch, spk, time] - metric_func: a function which return best_metric and best_perm - - Returns: - the average of best_metric - """ - return metric_func(preds, target)[0].mean() - - -snr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=signal_noise_ratio, eval_func="max") -si_sdr_pit_scipy = partial( - naive_implementation_pit_scipy, metric_func=scale_invariant_signal_distortion_ratio, eval_func="max" -) +def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + return _reference_scipy_pit( + preds=preds, + target=target, + metric_func=scale_invariant_signal_distortion_ratio, + eval_func="max", + ) @pytest.mark.parametrize( "preds, target, ref_metric, metric_func, mode, eval_func", [ - (inputs1.preds, inputs1.target, snr_pit_scipy, signal_noise_ratio, "speaker-wise", "max"), + (inputs1.preds, inputs1.target, _reference_scipy_pit_snr, signal_noise_ratio, "speaker-wise", "max"), ( inputs1.preds, inputs1.target, - si_sdr_pit_scipy, + _reference_scipy_pit_si_sdr, scale_invariant_signal_distortion_ratio, "speaker-wise", "max", ), - (inputs2.preds, inputs2.target, snr_pit_scipy, signal_noise_ratio, "speaker-wise", "max"), + (inputs2.preds, inputs2.target, _reference_scipy_pit_snr, signal_noise_ratio, "speaker-wise", "max"), ( inputs2.preds, inputs2.target, - si_sdr_pit_scipy, + _reference_scipy_pit_si_sdr, scale_invariant_signal_distortion_ratio, "speaker-wise", "max", ), - (inputs1.preds, inputs1.target, snr_pit_scipy, signal_noise_ratio, "permutation-wise", "max"), + (inputs1.preds, inputs1.target, _reference_scipy_pit_snr, signal_noise_ratio, "permutation-wise", "max"), ( inputs1.preds, inputs1.target, - si_sdr_pit_scipy, + _reference_scipy_pit_si_sdr, scale_invariant_signal_distortion_ratio, "permutation-wise", "max", ), - (inputs2.preds, inputs2.target, snr_pit_scipy, signal_noise_ratio, "permutation-wise", "max"), + (inputs2.preds, inputs2.target, _reference_scipy_pit_snr, signal_noise_ratio, "permutation-wise", "max"), ( inputs2.preds, inputs2.target, - si_sdr_pit_scipy, + 
_reference_scipy_pit_si_sdr, scale_invariant_signal_distortion_ratio, "permutation-wise", "max", @@ -163,7 +159,7 @@ def test_pit(self, preds, target, ref_metric, metric_func, mode, eval_func, ddp) preds, target, PermutationInvariantTraining, - reference_metric=partial(_average_metric, metric_func=ref_metric), + reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric, res_index=0), metric_args={"metric_func": metric_func, "mode": mode, "eval_func": eval_func}, ) diff --git a/tests/unittests/audio/test_sa_sdr.py b/tests/unittests/audio/test_sa_sdr.py index c852e0959a5..d6e7178b8ad 100644 --- a/tests/unittests/audio/test_sa_sdr.py +++ b/tests/unittests/audio/test_sa_sdr.py @@ -38,7 +38,9 @@ ) -def _ref_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool): +def _reference_local_sa_sdr( + preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool, reduce_mean: bool = False +): # According to the original paper, the sa-sdr equals to si-sdr with inputs concatenated over the speaker # dimension if scale_invariant==True. Accordingly, for scale_invariant==False, the sa-sdr equals to snr. # shape: preds [BATCH_SIZE, Spk, Time] , target [BATCH_SIZE, Spk, Time] @@ -51,14 +53,14 @@ def _ref_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: preds = preds.reshape(preds.shape[0], preds.shape[1] * preds.shape[2]) target = target.reshape(target.shape[0], target.shape[1] * target.shape[2]) if scale_invariant: - return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=False) - return signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean) - - -def _average_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return _ref_metric(preds, target, scale_invariant, zero_mean).mean() + sa_sdr = scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=False) + else: + sa_sdr = signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean) + if reduce_mean: + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + return sa_sdr.mean() + return sa_sdr @pytest.mark.parametrize( @@ -83,7 +85,9 @@ def test_si_sdr(self, preds, target, scale_invariant, zero_mean, ddp): preds, target, SourceAggregatedSignalDistortionRatio, - reference_metric=partial(_average_metric, scale_invariant=scale_invariant, zero_mean=zero_mean), + reference_metric=partial( + _reference_local_sa_sdr, scale_invariant=scale_invariant, zero_mean=zero_mean, reduce_mean=True + ), metric_args={ "scale_invariant": scale_invariant, "zero_mean": zero_mean, @@ -96,7 +100,7 @@ def test_sa_sdr_functional(self, preds, target, scale_invariant, zero_mean): preds, target, source_aggregated_signal_distortion_ratio, - reference_metric=partial(_ref_metric, scale_invariant=scale_invariant, zero_mean=zero_mean), + reference_metric=partial(_reference_local_sa_sdr, scale_invariant=scale_invariant, zero_mean=zero_mean), metric_args={ "scale_invariant": scale_invariant, "zero_mean": zero_mean, diff --git a/tests/unittests/audio/test_sdr.py b/tests/unittests/audio/test_sdr.py index 80b1ff49764..ce8756d8ec7 100644 --- a/tests/unittests/audio/test_sdr.py +++ b/tests/unittests/audio/test_sdr.py @@ -12,7 +12,6 @@ # See the License for the 
specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable import numpy as np import pytest @@ -43,7 +42,9 @@ ) -def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool = False) -> Tensor: +def _reference_sdr_batch( + preds: Tensor, target: Tensor, compute_permutation: bool = False, reduce_mean: bool = False +) -> Tensor: # shape: preds [BATCH_SIZE, spk, Time] , target [BATCH_SIZE, spk, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, spk, Time] , target [NUM_BATCHES*BATCH_SIZE, spk, Time] target = target.detach().cpu().numpy() @@ -52,27 +53,20 @@ def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool for b in range(preds.shape[0]): sdr_val_np, _, _, _ = bss_eval_sources(target[b], preds[b], compute_permutation) mss.append(sdr_val_np) - return torch.tensor(np.array(mss)) + sdr = torch.tensor(np.array(mss)) + if reduce_mean: + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + return sdr.mean() + return sdr -def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor: - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target).mean() - - -original_impl_compute_permutation = partial(_sdr_original_batch) - - -@pytest.mark.skipif( # TODO: figure out why tests leads to cuda errors on latest torch +@pytest.mark.skipif( # FIXME: figure out why tests leads to cuda errors on latest torch _TORCH_GREATER_EQUAL_1_11 and torch.cuda.is_available(), reason="tests leads to cuda errors on latest torch" ) @pytest.mark.parametrize( - "preds, target, ref_metric", - [ - (inputs_1spk.preds, inputs_1spk.target, original_impl_compute_permutation), - (inputs_2spk.preds, inputs_2spk.target, original_impl_compute_permutation), - ], + "preds, target", + [(inputs_1spk.preds, inputs_1spk.target), (inputs_2spk.preds, inputs_2spk.target)], ) class TestSDR(MetricTester): """Test class for `SignalDistortionRatio` metric.""" @@ -80,28 +74,28 @@ class TestSDR(MetricTester): atol = 1e-2 @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_sdr(self, preds, target, ref_metric, ddp): + def test_sdr(self, preds, target, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp, preds, target, SignalDistortionRatio, - reference_metric=partial(_average_metric, metric_func=ref_metric), + reference_metric=partial(_reference_sdr_batch, reduce_mean=True), metric_args={}, ) - def test_sdr_functional(self, preds, target, ref_metric): + def test_sdr_functional(self, preds, target): """Test functional implementation of metric.""" self.run_functional_metric_test( preds, target, signal_distortion_ratio, - ref_metric, + _reference_sdr_batch, metric_args={}, ) - def test_sdr_differentiability(self, preds, target, ref_metric): + def test_sdr_differentiability(self, preds, target): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" self.run_differentiability_test( preds=preds, @@ -110,7 +104,7 @@ def test_sdr_differentiability(self, preds, target, ref_metric): metric_args={}, ) - def test_sdr_half_cpu(self, preds, target, ref_metric): + def test_sdr_half_cpu(self, preds, target): """Test dtype support of the metric on CPU.""" 
self.run_precision_test_cpu( preds=preds, @@ -121,7 +115,7 @@ def test_sdr_half_cpu(self, preds, target, ref_metric): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_sdr_half_gpu(self, preds, target, ref_metric): + def test_sdr_half_gpu(self, preds, target): """Test dtype support of the metric on GPU.""" self.run_precision_test_gpu( preds=preds, diff --git a/tests/unittests/audio/test_si_sdr.py b/tests/unittests/audio/test_si_sdr.py index d314219542b..6f014f828eb 100644 --- a/tests/unittests/audio/test_si_sdr.py +++ b/tests/unittests/audio/test_si_sdr.py @@ -21,6 +21,7 @@ from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio from unittests import BATCH_SIZE, NUM_BATCHES, _Input +from unittests.audio import _average_metric_wrapper from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -34,10 +35,9 @@ target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES), ) -speechmetrics_sisdr = speechmetrics.load("sisdr") - -def _speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool): +def _reference_speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool): + speechmetrics_sisdr = speechmetrics.load("sisdr") # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] if zero_mean: @@ -55,21 +55,11 @@ def _speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool): return torch.tensor(mss) -def _average_metric(preds, target, metric_func): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target).mean() - - -speechmetrics_si_sdr_zero_mean = partial(_speechmetrics_si_sdr, zero_mean=True) -speechmetrics_si_sdr_no_zero_mean = partial(_speechmetrics_si_sdr, zero_mean=False) - - @pytest.mark.parametrize( "preds, target, ref_metric, zero_mean", [ - (inputs.preds, inputs.target, speechmetrics_si_sdr_zero_mean, True), - (inputs.preds, inputs.target, speechmetrics_si_sdr_no_zero_mean, False), + (inputs.preds, inputs.target, partial(_reference_speechmetrics_si_sdr, zero_mean=True), True), + (inputs.preds, inputs.target, partial(_reference_speechmetrics_si_sdr, zero_mean=False), False), ], ) class TestSISDR(MetricTester): @@ -85,7 +75,7 @@ def test_si_sdr(self, preds, target, ref_metric, zero_mean, ddp): preds, target, ScaleInvariantSignalDistortionRatio, - reference_metric=partial(_average_metric, metric_func=ref_metric), + reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric), metric_args={"zero_mean": zero_mean}, ) diff --git a/tests/unittests/audio/test_si_snr.py b/tests/unittests/audio/test_si_snr.py index 8cd28afb258..f6f6c7f52f7 100644 --- a/tests/unittests/audio/test_si_snr.py +++ b/tests/unittests/audio/test_si_snr.py @@ -37,7 +37,7 @@ speechmetrics_sisdr = speechmetrics.load("sisdr") -def _speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True): +def _reference_speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True, reduce_mean: bool = False): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] if zero_mean: @@ -52,20 +52,17 @@ def _speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True) metric = speechmetrics_sisdr(preds[i, j], 
target[i, j], rate=16000) ms.append(metric["sisdr"][0]) mss.append(ms) - return torch.tensor(mss) - - -def _average_metric(preds, target, metric_func): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target).mean() + si_sdr = torch.tensor(mss) + if reduce_mean: + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + return si_sdr.mean() + return si_sdr @pytest.mark.parametrize( "preds, target, ref_metric", - [ - (inputs.preds, inputs.target, _speechmetrics_si_sdr), - ], + [(inputs.preds, inputs.target, _reference_speechmetrics_si_sdr)], ) class TestSISNR(MetricTester): """Test class for `ScaleInvariantSignalNoiseRatio` metric.""" @@ -80,7 +77,7 @@ def test_si_snr(self, preds, target, ref_metric, ddp): preds, target, ScaleInvariantSignalNoiseRatio, - reference_metric=partial(_average_metric, metric_func=ref_metric), + reference_metric=partial(_reference_speechmetrics_si_sdr, reduce_mean=True), ) def test_si_snr_functional(self, preds, target, ref_metric): diff --git a/tests/unittests/audio/test_snr.py b/tests/unittests/audio/test_snr.py index ad3896ed3fc..d513360851a 100644 --- a/tests/unittests/audio/test_snr.py +++ b/tests/unittests/audio/test_snr.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable import pytest import torch @@ -22,6 +21,7 @@ from torchmetrics.functional.audio import signal_noise_ratio from unittests import _Input +from unittests.audio import _average_metric_wrapper from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -34,7 +34,7 @@ ) -def _bss_eval_images_snr(preds: Tensor, target: Tensor, zero_mean: bool): +def _reference_bss_snr(preds: Tensor, target: Tensor, zero_mean: bool): # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] if zero_mean: @@ -52,21 +52,11 @@ def _bss_eval_images_snr(preds: Tensor, target: Tensor, zero_mean: bool): return torch.tensor(mss) -def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target).mean() - - -mireval_snr_zeromean = partial(_bss_eval_images_snr, zero_mean=True) -mireval_snr_nozeromean = partial(_bss_eval_images_snr, zero_mean=False) - - @pytest.mark.parametrize( "preds, target, ref_metric, zero_mean", [ - (inputs.preds, inputs.target, mireval_snr_zeromean, True), - (inputs.preds, inputs.target, mireval_snr_nozeromean, False), + (inputs.preds, inputs.target, partial(_reference_bss_snr, zero_mean=True), True), + (inputs.preds, inputs.target, partial(_reference_bss_snr, zero_mean=False), False), ], ) class TestSNR(MetricTester): @@ -82,7 +72,7 @@ def test_snr(self, preds, target, ref_metric, zero_mean, ddp): preds, target, SignalNoiseRatio, - reference_metric=partial(_average_metric, metric_func=ref_metric), + reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric), metric_args={"zero_mean": zero_mean}, ) diff --git a/tests/unittests/audio/test_srmr.py 
b/tests/unittests/audio/test_srmr.py index e8f3c3a59b7..fa0ce989a61 100644 --- a/tests/unittests/audio/test_srmr.py +++ b/tests/unittests/audio/test_srmr.py @@ -30,7 +30,9 @@ preds = torch.rand(2, 2, 8000) -def _ref_metric_batch(preds: Tensor, target: Tensor, fs: int, fast: bool, norm: bool, **kwargs: Dict[str, Any]): +def _reference_srmr_batch( + preds: Tensor, target: Tensor, fs: int, fast: bool, norm: bool, reduce_mean: bool = False, **kwargs: Dict[str, Any] +): # shape: preds [BATCH_SIZE, Time] shape = preds.shape preds = preds.reshape(1, -1) if len(shape) == 1 else preds.reshape(-1, shape[-1]) @@ -42,13 +44,12 @@ def _ref_metric_batch(preds: Tensor, target: Tensor, fs: int, fast: bool, norm: val, _ = srmrpy_srmr(preds[b, ...], fs=fs, fast=fast, norm=norm, max_cf=128 if not norm else 30) score.append(val) score = torch.tensor(score) - return score.reshape(*shape[:-1]) - - -def _average_metric(preds, target, metric_func, **kwargs: Dict[str, Any]): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target, **kwargs).mean() + srmr = score.reshape(*shape[:-1]) + if reduce_mean: + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + return srmr.mean() + return srmr def _speech_reverberation_modulation_energy_ratio_cheat(preds, target, **kwargs: Dict[str, Any]): @@ -62,6 +63,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: super().update(preds=preds) +# FIXME: bring compatibility with torchaudio 0.10+ @pytest.mark.skipif(not _TORCHAUDIO_GREATER_EQUAL_0_10, reason="torchaudio>=0.10.0 is required") @pytest.mark.parametrize( "preds, fs, fast, norm", @@ -89,7 +91,7 @@ def test_srmr(self, preds, fs, fast, norm, ddp): preds=preds, target=preds, metric_class=_SpeechReverberationModulationEnergyRatioCheat, - reference_metric=partial(_average_metric, metric_func=_ref_metric_batch, fs=fs, fast=fast, norm=norm), + reference_metric=partial(_reference_srmr_batch, fs=fs, fast=fast, norm=norm, reduce_mean=True), metric_args={"fs": fs, "fast": fast, "norm": norm}, ) @@ -99,7 +101,7 @@ def test_srmr_functional(self, preds, fs, fast, norm): preds=preds, target=preds, metric_functional=_speech_reverberation_modulation_energy_ratio_cheat, - reference_metric=partial(_ref_metric_batch, fs=fs, fast=fast, norm=norm), + reference_metric=partial(_reference_srmr_batch, fs=fs, fast=fast, norm=norm), metric_args={"fs": fs, "fast": fast, "norm": norm}, ) diff --git a/tests/unittests/audio/test_stoi.py b/tests/unittests/audio/test_stoi.py index 4dc99754a48..d05824a1380 100644 --- a/tests/unittests/audio/test_stoi.py +++ b/tests/unittests/audio/test_stoi.py @@ -22,7 +22,7 @@ from torchmetrics.functional.audio import short_time_objective_intelligibility from unittests import _Input -from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB +from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -39,7 +39,7 @@ ) -def _stoi_original_batch(preds: Tensor, target: Tensor, fs: int, extended: bool): +def _reference_stoi_batch(preds: Tensor, target: Tensor, fs: int, extended: bool): # shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time] # or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target 
[NUM_BATCHES*BATCH_SIZE, Time] target = target.detach().cpu().numpy() @@ -51,25 +51,13 @@ def _stoi_original_batch(preds: Tensor, target: Tensor, fs: int, extended: bool) return torch.tensor(mss) -def _average_metric(preds, target, metric_func): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] - return metric_func(preds, target).mean() - - -stoi_original_batch_8k_ext = partial(_stoi_original_batch, fs=8000, extended=True) -stoi_original_batch_16k_ext = partial(_stoi_original_batch, fs=16000, extended=True) -stoi_original_batch_8k_noext = partial(_stoi_original_batch, fs=8000, extended=False) -stoi_original_batch_16k_noext = partial(_stoi_original_batch, fs=16000, extended=False) - - @pytest.mark.parametrize( "preds, target, ref_metric, fs, extended", [ - (inputs_8k.preds, inputs_8k.target, stoi_original_batch_8k_ext, 8000, True), - (inputs_16k.preds, inputs_16k.target, stoi_original_batch_16k_ext, 16000, True), - (inputs_8k.preds, inputs_8k.target, stoi_original_batch_8k_noext, 8000, False), - (inputs_16k.preds, inputs_16k.target, stoi_original_batch_16k_noext, 16000, False), + (inputs_8k.preds, inputs_8k.target, partial(_reference_stoi_batch, fs=8000, extended=True), 8000, True), + (inputs_16k.preds, inputs_16k.target, partial(_reference_stoi_batch, fs=16000, extended=True), 16000, True), + (inputs_8k.preds, inputs_8k.target, partial(_reference_stoi_batch, fs=8000, extended=False), 8000, False), + (inputs_16k.preds, inputs_16k.target, partial(_reference_stoi_batch, fs=16000, extended=False), 16000, False), ], ) class TestSTOI(MetricTester): @@ -85,7 +73,7 @@ def test_stoi(self, preds, target, ref_metric, fs, extended, ddp): preds, target, ShortTimeObjectiveIntelligibility, - reference_metric=partial(_average_metric, metric_func=ref_metric), + reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric), metric_args={"fs": fs, "extended": extended}, ) diff --git a/tests/unittests/bases/test_aggregation.py b/tests/unittests/bases/test_aggregation.py index 5a65eaa3fb4..25ff2233076 100644 --- a/tests/unittests/bases/test_aggregation.py +++ b/tests/unittests/bases/test_aggregation.py @@ -96,11 +96,11 @@ def test_aggreagation(self, ddp, metric_class, compare_fn, values, weights): ) -_case1 = float("nan") * torch.ones(5) -_case2 = torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0]) +_CASE_1 = float("nan") * torch.ones(5) +_CASE_2 = torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0]) -@pytest.mark.parametrize("value", [_case1, _case2]) +@pytest.mark.parametrize("value", [_CASE_1, _CASE_2]) @pytest.mark.parametrize("nan_strategy", ["error", "warn"]) @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_nan_error(value, nan_strategy, metric_class): @@ -117,26 +117,26 @@ def test_nan_error(value, nan_strategy, metric_class): @pytest.mark.parametrize( ("metric_class", "nan_strategy", "value", "expected"), [ - (MinMetric, "ignore", _case1, torch.tensor(float("inf"))), - (MinMetric, 2.0, _case1, 2.0), - (MinMetric, "ignore", _case2, 1.0), - (MinMetric, 2.0, _case2, 1.0), - (MaxMetric, "ignore", _case1, -torch.tensor(float("inf"))), - (MaxMetric, 2.0, _case1, 2.0), - (MaxMetric, "ignore", _case2, 5.0), - (MaxMetric, 2.0, _case2, 5.0), - (SumMetric, "ignore", _case1, 0.0), - (SumMetric, 2.0, _case1, 10.0), - (SumMetric, "ignore", _case2, 12.0), - (SumMetric, 2.0, _case2, 14.0), - (MeanMetric, "ignore", _case1, 
torch.tensor([float("nan")])), - (MeanMetric, 2.0, _case1, 2.0), - (MeanMetric, "ignore", _case2, 3.0), - (MeanMetric, 2.0, _case2, 2.8), - (CatMetric, "ignore", _case1, []), - (CatMetric, 2.0, _case1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])), - (CatMetric, "ignore", _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])), - (CatMetric, 2.0, _case2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])), + (MinMetric, "ignore", _CASE_1, torch.tensor(float("inf"))), + (MinMetric, 2.0, _CASE_1, 2.0), + (MinMetric, "ignore", _CASE_2, 1.0), + (MinMetric, 2.0, _CASE_2, 1.0), + (MaxMetric, "ignore", _CASE_1, -torch.tensor(float("inf"))), + (MaxMetric, 2.0, _CASE_1, 2.0), + (MaxMetric, "ignore", _CASE_2, 5.0), + (MaxMetric, 2.0, _CASE_2, 5.0), + (SumMetric, "ignore", _CASE_1, 0.0), + (SumMetric, 2.0, _CASE_1, 10.0), + (SumMetric, "ignore", _CASE_2, 12.0), + (SumMetric, 2.0, _CASE_2, 14.0), + (MeanMetric, "ignore", _CASE_1, torch.tensor([float("nan")])), + (MeanMetric, 2.0, _CASE_1, 2.0), + (MeanMetric, "ignore", _CASE_2, 3.0), + (MeanMetric, 2.0, _CASE_2, 2.8), + (CatMetric, "ignore", _CASE_1, []), + (CatMetric, 2.0, _CASE_1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])), + (CatMetric, "ignore", _CASE_2, torch.tensor([1.0, 2.0, 4.0, 5.0])), + (CatMetric, 2.0, _CASE_2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])), (CatMetric, "ignore", torch.zeros(5), torch.zeros(5)), ], ) diff --git a/tests/unittests/classification/inputs.py b/tests/unittests/classification/_inputs.py similarity index 100% rename from tests/unittests/classification/inputs.py rename to tests/unittests/classification/_inputs.py diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index f2954f8b620..bd5182c9522 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -29,19 +29,19 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_accuracy(target, preds): +def _reference_sklearn_accuracy(target, preds): score = sk_accuracy(target, preds) return score if not np.isnan(score) else 0.0 -def _sklearn_accuracy_binary(preds, target, ignore_index, multidim_average): +def _reference_sklearn_accuracy_binary(preds, target, ignore_index, multidim_average): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -56,14 +56,14 @@ def _sklearn_accuracy_binary(preds, target, ignore_index, multidim_average): if multidim_average == "global": target, preds = remove_ignore_index(target, preds, ignore_index) - return _sklearn_accuracy(target, preds) + return _reference_sklearn_accuracy(target, preds) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - res.append(_sklearn_accuracy(true, pred)) + res.append(_reference_sklearn_accuracy(true, pred)) return np.stack(res) @@ -108,7 +108,7 @@ def test_binary_accuracy(self, ddp, inputs, ignore_index, multidim_average): target=target, metric_class=BinaryAccuracy, reference_metric=partial( - _sklearn_accuracy_binary, ignore_index=ignore_index, multidim_average=multidim_average 
+ _reference_sklearn_accuracy_binary, ignore_index=ignore_index, multidim_average=multidim_average ), metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) @@ -128,7 +128,7 @@ def test_binary_accuracy_functional(self, inputs, ignore_index, multidim_average target=target, metric_functional=binary_accuracy, reference_metric=partial( - _sklearn_accuracy_binary, ignore_index=ignore_index, multidim_average=multidim_average + _reference_sklearn_accuracy_binary, ignore_index=ignore_index, multidim_average=multidim_average ), metric_args={ "threshold": THRESHOLD, @@ -179,7 +179,7 @@ def test_binary_accuracy_half_gpu(self, inputs, dtype): ) -def _sklearn_accuracy_multiclass(preds, target, ignore_index, multidim_average, average): +def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) if multidim_average == "global": @@ -187,7 +187,7 @@ def _sklearn_accuracy_multiclass(preds, target, ignore_index, multidim_average, target = target.numpy().flatten() target, preds = remove_ignore_index(target, preds, ignore_index) if average == "micro": - return _sklearn_accuracy(target, preds) + return _reference_sklearn_accuracy(target, preds) confmat = sk_confusion_matrix(target, preds, labels=list(range(NUM_CLASSES))) acc_per_class = confmat.diagonal() / confmat.sum(axis=1) acc_per_class[np.isnan(acc_per_class)] = 0.0 @@ -209,7 +209,7 @@ def _sklearn_accuracy_multiclass(preds, target, ignore_index, multidim_average, true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) if average == "micro": - res.append(_sklearn_accuracy(true, pred)) + res.append(_reference_sklearn_accuracy(true, pred)) else: confmat = sk_confusion_matrix(true, pred, labels=list(range(NUM_CLASSES))) acc_per_class = confmat.diagonal() / confmat.sum(axis=1) @@ -252,7 +252,7 @@ def test_multiclass_accuracy(self, ddp, inputs, ignore_index, multidim_average, target=target, metric_class=MulticlassAccuracy, reference_metric=partial( - _sklearn_accuracy_multiclass, + _reference_sklearn_accuracy_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -281,7 +281,7 @@ def test_multiclass_accuracy_functional(self, inputs, ignore_index, multidim_ave target=target, metric_functional=multiclass_accuracy, reference_metric=partial( - _sklearn_accuracy_multiclass, + _reference_sklearn_accuracy_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -355,7 +355,7 @@ def test_top_k(k, preds, target, average, expected): assert torch.isclose(multiclass_accuracy(preds, target, top_k=k, average=average, num_classes=3), expected) -def _sklearn_accuracy_multilabel(preds, target, ignore_index, multidim_average, average): +def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -370,14 +370,14 @@ def _sklearn_accuracy_multilabel(preds, target, ignore_index, multidim_average, preds = preds.flatten() target = target.flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - return _sklearn_accuracy(target, preds) + return _reference_sklearn_accuracy(target, preds) accuracy, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() true, pred = remove_ignore_index(true, pred, ignore_index) confmat = 
sk_confusion_matrix(true, pred, labels=[0, 1]) - accuracy.append(_sklearn_accuracy(true, pred)) + accuracy.append(_reference_sklearn_accuracy(true, pred)) weights.append(confmat[1, 1] + confmat[1, 0]) res = np.stack(accuracy, axis=0) @@ -397,7 +397,7 @@ def _sklearn_accuracy_multilabel(preds, target, ignore_index, multidim_average, if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - accuracy.append(_sklearn_accuracy(true, pred)) + accuracy.append(_reference_sklearn_accuracy(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) else: @@ -405,7 +405,7 @@ def _sklearn_accuracy_multilabel(preds, target, ignore_index, multidim_average, for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] true, pred = remove_ignore_index(true, pred, ignore_index) - scores.append(_sklearn_accuracy(true, pred)) + scores.append(_reference_sklearn_accuracy(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) accuracy.append(np.stack(scores)) @@ -449,7 +449,7 @@ def test_multilabel_accuracy(self, ddp, inputs, ignore_index, multidim_average, target=target, metric_class=MultilabelAccuracy, reference_metric=partial( - _sklearn_accuracy_multilabel, + _reference_sklearn_accuracy_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -479,7 +479,7 @@ def test_multilabel_accuracy_functional(self, inputs, ignore_index, multidim_ave target=target, metric_functional=multilabel_accuracy, reference_metric=partial( - _sklearn_accuracy_multilabel, + _reference_sklearn_accuracy_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, average=average, diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index a6c30271388..2b043edbbe9 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -25,14 +25,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_auroc_binary(preds, target, max_fpr=None, ignore_index=None): +def _reference_sklearn_auroc_binary(preds, target, max_fpr=None, ignore_index=None): preds = preds.flatten().numpy() target = target.flatten().numpy() if not ((preds > 0) & (preds < 1)).all(): @@ -58,7 +58,7 @@ def test_binary_auroc(self, inputs, ddp, max_fpr, ignore_index): preds=preds, target=target, metric_class=BinaryAUROC, - reference_metric=partial(_sklearn_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), metric_args={ "max_fpr": max_fpr, "thresholds": None, @@ -77,7 +77,7 @@ def test_binary_auroc_functional(self, inputs, max_fpr, ignore_index): preds=preds, target=target, metric_functional=binary_auroc, - reference_metric=partial(_sklearn_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), metric_args={ "max_fpr": max_fpr, "thresholds": None, 
@@ -138,7 +138,7 @@ def test_binary_auroc_threshold_arg(self, inputs, threshold_fn): assert torch.allclose(ap1, ap2) -def _sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None): +def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -166,7 +166,7 @@ def test_multiclass_auroc(self, inputs, average, ddp, ignore_index): preds=preds, target=target, metric_class=MulticlassAUROC, - reference_metric=partial(_sklearn_auroc_multiclass, average=average, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_auroc_multiclass, average=average, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_classes": NUM_CLASSES, @@ -186,7 +186,7 @@ def test_multiclass_auroc_functional(self, inputs, average, ignore_index): preds=preds, target=target, metric_functional=multiclass_auroc, - reference_metric=partial(_sklearn_auroc_multiclass, average=average, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_auroc_multiclass, average=average, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_classes": NUM_CLASSES, @@ -251,7 +251,7 @@ def test_multiclass_auroc_threshold_arg(self, inputs, average): assert torch.allclose(ap1, ap2) -def _sklearn_auroc_multilabel(preds, target, average="macro", ignore_index=None): +def _reference_sklearn_auroc_multilabel(preds, target, average="macro", ignore_index=None): if ignore_index is None: if preds.ndim > 2: target = target.transpose(2, 1).reshape(-1, NUM_CLASSES) @@ -262,9 +262,11 @@ def _sklearn_auroc_multilabel(preds, target, average="macro", ignore_index=None) preds = sigmoid(preds) return sk_roc_auc_score(target, preds, average=average, max_fpr=None) if average == "micro": - return _sklearn_auroc_binary(preds.flatten(), target.flatten(), max_fpr=None, ignore_index=ignore_index) + return _reference_sklearn_auroc_binary( + preds.flatten(), target.flatten(), max_fpr=None, ignore_index=ignore_index + ) res = [ - _sklearn_auroc_binary(preds[:, i], target[:, i], max_fpr=None, ignore_index=ignore_index) + _reference_sklearn_auroc_binary(preds[:, i], target[:, i], max_fpr=None, ignore_index=ignore_index) for i in range(NUM_CLASSES) ] if average == "macro": @@ -295,7 +297,7 @@ def test_multilabel_auroc(self, inputs, ddp, average, ignore_index): preds=preds, target=target, metric_class=MultilabelAUROC, - reference_metric=partial(_sklearn_auroc_multilabel, average=average, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_auroc_multilabel, average=average, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_labels": NUM_CLASSES, @@ -315,7 +317,7 @@ def test_multilabel_auroc_functional(self, inputs, average, ignore_index): preds=preds, target=target, metric_functional=multilabel_auroc, - reference_metric=partial(_sklearn_auroc_multilabel, average=average, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_auroc_multilabel, average=average, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_labels": NUM_CLASSES, diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index cdb76ffce31..a88af1bd009 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -34,14 +34,14 @@ from 
torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_avg_precision_binary(preds, target, ignore_index=None): +def _reference_sklearn_avg_precision_binary(preds, target, ignore_index=None): preds = preds.flatten().numpy() target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): @@ -66,7 +66,7 @@ def test_binary_average_precision(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=BinaryAveragePrecision, - reference_metric=partial(_sklearn_avg_precision_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_avg_precision_binary, ignore_index=ignore_index), metric_args={ "thresholds": None, "ignore_index": ignore_index, @@ -83,7 +83,7 @@ def test_binary_average_precision_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=binary_average_precision, - reference_metric=partial(_sklearn_avg_precision_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_avg_precision_binary, ignore_index=ignore_index), metric_args={ "thresholds": None, "ignore_index": ignore_index, @@ -142,7 +142,7 @@ def test_binary_average_precision_threshold_arg(self, inputs, threshold_fn): assert torch.allclose(ap1, ap2) -def _sklearn_avg_precision_multiclass(preds, target, average="macro", ignore_index=None): +def _reference_sklearn_avg_precision_multiclass(preds, target, average="macro", ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -182,7 +182,9 @@ def test_multiclass_average_precision(self, inputs, average, ddp, ignore_index): preds=preds, target=target, metric_class=MulticlassAveragePrecision, - reference_metric=partial(_sklearn_avg_precision_multiclass, average=average, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_avg_precision_multiclass, average=average, ignore_index=ignore_index + ), metric_args={ "thresholds": None, "num_classes": NUM_CLASSES, @@ -202,7 +204,9 @@ def test_multiclass_average_precision_functional(self, inputs, average, ignore_i preds=preds, target=target, metric_functional=multiclass_average_precision, - reference_metric=partial(_sklearn_avg_precision_multiclass, average=average, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_avg_precision_multiclass, average=average, ignore_index=ignore_index + ), metric_args={ "thresholds": None, "num_classes": NUM_CLASSES, @@ -266,10 +270,10 @@ def test_multiclass_average_precision_threshold_arg(self, inputs, average): assert torch.allclose(ap1, ap2) -def _sklearn_avg_precision_multilabel(preds, target, average="macro", ignore_index=None): +def _reference_sklearn_avg_precision_multilabel(preds, target, average="macro", ignore_index=None): if average == "micro": - return _sklearn_avg_precision_binary(preds.flatten(), target.flatten(), ignore_index) - res = [_sklearn_avg_precision_binary(preds[:, i], target[:, i], ignore_index) for i in range(NUM_CLASSES)] + return _reference_sklearn_avg_precision_binary(preds.flatten(), target.flatten(), 
ignore_index) + res = [_reference_sklearn_avg_precision_binary(preds[:, i], target[:, i], ignore_index) for i in range(NUM_CLASSES)] if average == "macro": return np.array(res)[~np.isnan(res)].mean() if average == "weighted": @@ -298,7 +302,9 @@ def test_multilabel_average_precision(self, inputs, ddp, average, ignore_index): preds=preds, target=target, metric_class=MultilabelAveragePrecision, - reference_metric=partial(_sklearn_avg_precision_multilabel, average=average, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_avg_precision_multilabel, average=average, ignore_index=ignore_index + ), metric_args={ "thresholds": None, "num_labels": NUM_CLASSES, @@ -318,7 +324,9 @@ def test_multilabel_average_precision_functional(self, inputs, average, ignore_i preds=preds, target=target, metric_functional=multilabel_average_precision, - reference_metric=partial(_sklearn_avg_precision_multilabel, average=average, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_avg_precision_multilabel, average=average, ignore_index=ignore_index + ), metric_args={ "thresholds": None, "num_labels": NUM_CLASSES, diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 83822bb9c4e..5660a9042f0 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -32,14 +32,14 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_13 from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _netcal_binary_calibration_error(preds, target, n_bins, norm, ignore_index): +def _reference_netcal_binary_calibration_error(preds, target, n_bins, norm, ignore_index): preds = preds.numpy().flatten() target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -68,7 +68,7 @@ def test_binary_calibration_error(self, inputs, ddp, n_bins, norm, ignore_index) target=target, metric_class=BinaryCalibrationError, reference_metric=partial( - _netcal_binary_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index + _reference_netcal_binary_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index ), metric_args={ "n_bins": n_bins, @@ -90,7 +90,7 @@ def test_binary_calibration_error_functional(self, inputs, n_bins, norm, ignore_ target=target, metric_functional=binary_calibration_error, reference_metric=partial( - _netcal_binary_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index + _reference_netcal_binary_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index ), metric_args={ "n_bins": n_bins, @@ -148,7 +148,7 @@ def test_binary_with_zero_pred(): assert binary_calibration_error(preds, target, n_bins=2, norm="l1") == torch.tensor(0.6) -def _netcal_multiclass_calibration_error(preds, target, n_bins, norm, ignore_index): +def _reference_netcal_multiclass_calibration_error(preds, target, n_bins, norm, ignore_index): preds = preds.numpy() target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -180,7 +180,7 @@ def test_multiclass_calibration_error(self, inputs, ddp, n_bins, norm, ignore_in target=target, metric_class=MulticlassCalibrationError, 
reference_metric=partial( - _netcal_multiclass_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index + _reference_netcal_multiclass_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index ), metric_args={ "num_classes": NUM_CLASSES, @@ -203,7 +203,7 @@ def test_multiclass_calibration_error_functional(self, inputs, n_bins, norm, ign target=target, metric_functional=multiclass_calibration_error, reference_metric=partial( - _netcal_multiclass_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index + _reference_netcal_multiclass_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index ), metric_args={ "num_classes": NUM_CLASSES, diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index c2f22736663..6b21b9be3ff 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -23,14 +23,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_cohen_kappa_binary(preds, target, weights=None, ignore_index=None): +def _reference_sklearn_cohen_kappa_binary(preds, target, weights=None, ignore_index=None): preds = preds.view(-1).numpy() target = target.view(-1).numpy() if np.issubdtype(preds.dtype, np.floating): @@ -60,7 +60,7 @@ def test_binary_cohen_kappa(self, inputs, ddp, weights, ignore_index): preds=preds, target=target, metric_class=BinaryCohenKappa, - reference_metric=partial(_sklearn_cohen_kappa_binary, weights=weights, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_cohen_kappa_binary, weights=weights, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, "weights": weights, @@ -79,7 +79,7 @@ def test_binary_confusion_matrix_functional(self, inputs, weights, ignore_index) preds=preds, target=target, metric_functional=binary_cohen_kappa, - reference_metric=partial(_sklearn_cohen_kappa_binary, weights=weights, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_cohen_kappa_binary, weights=weights, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, "weights": weights, @@ -129,7 +129,7 @@ def test_binary_confusion_matrix_dtypes_gpu(self, inputs, dtype): ) -def _sklearn_cohen_kappa_multiclass(preds, target, weights, ignore_index=None): +def _reference_sklearn_cohen_kappa_multiclass(preds, target, weights, ignore_index=None): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -159,7 +159,9 @@ def test_multiclass_cohen_kappa(self, inputs, ddp, weights, ignore_index): preds=preds, target=target, metric_class=MulticlassCohenKappa, - reference_metric=partial(_sklearn_cohen_kappa_multiclass, weights=weights, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_cohen_kappa_multiclass, weights=weights, ignore_index=ignore_index + ), metric_args={ "num_classes": NUM_CLASSES, "weights": weights, @@ -178,7 +180,9 @@ def test_multiclass_confusion_matrix_functional(self, inputs, weights, ignore_in preds=preds, target=target, metric_functional=multiclass_cohen_kappa, - reference_metric=partial(_sklearn_cohen_kappa_multiclass, weights=weights, 
ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_cohen_kappa_multiclass, weights=weights, ignore_index=ignore_index + ), metric_args={ "num_classes": NUM_CLASSES, "weights": weights, diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 5265f9a64eb..6a1c1850d4f 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -32,14 +32,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_confusion_matrix_binary(preds, target, normalize=None, ignore_index=None): +def _reference_sklearn_confusion_matrix_binary(preds, target, normalize=None, ignore_index=None): preds = preds.view(-1).numpy() target = target.view(-1).numpy() if np.issubdtype(preds.dtype, np.floating): @@ -67,7 +67,9 @@ def test_binary_confusion_matrix(self, inputs, ddp, normalize, ignore_index): preds=preds, target=target, metric_class=BinaryConfusionMatrix, - reference_metric=partial(_sklearn_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index + ), metric_args={ "threshold": THRESHOLD, "normalize": normalize, @@ -86,7 +88,9 @@ def test_binary_confusion_matrix_functional(self, inputs, normalize, ignore_inde preds=preds, target=target, metric_functional=binary_confusion_matrix, - reference_metric=partial(_sklearn_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index + ), metric_args={ "threshold": THRESHOLD, "normalize": normalize, @@ -136,7 +140,7 @@ def test_binary_confusion_matrix_dtype_gpu(self, inputs, dtype): ) -def _sklearn_confusion_matrix_multiclass(preds, target, normalize=None, ignore_index=None): +def _reference_sklearn_confusion_matrix_multiclass(preds, target, normalize=None, ignore_index=None): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -165,7 +169,7 @@ def test_multiclass_confusion_matrix(self, inputs, ddp, normalize, ignore_index) target=target, metric_class=MulticlassConfusionMatrix, reference_metric=partial( - _sklearn_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index + _reference_sklearn_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index ), metric_args={ "num_classes": NUM_CLASSES, @@ -186,7 +190,7 @@ def test_multiclass_confusion_matrix_functional(self, inputs, normalize, ignore_ target=target, metric_functional=multiclass_confusion_matrix, reference_metric=partial( - _sklearn_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index + _reference_sklearn_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index ), metric_args={ "num_classes": NUM_CLASSES, @@ -247,7 +251,7 @@ def test_multiclass_overflow(): assert torch.allclose(res, torch.tensor(compare)) -def _sklearn_confusion_matrix_multilabel(preds, target, normalize=None, 
ignore_index=None): +def _reference_sklearn_confusion_matrix_multilabel(preds, target, normalize=None, ignore_index=None): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -282,7 +286,7 @@ def test_multilabel_confusion_matrix(self, inputs, ddp, normalize, ignore_index) target=target, metric_class=MultilabelConfusionMatrix, reference_metric=partial( - _sklearn_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index + _reference_sklearn_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index ), metric_args={ "num_labels": NUM_CLASSES, @@ -303,7 +307,7 @@ def test_multilabel_confusion_matrix_functional(self, inputs, normalize, ignore_ target=target, metric_functional=multilabel_confusion_matrix, reference_metric=partial( - _sklearn_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index + _reference_sklearn_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index ), metric_args={ "num_labels": NUM_CLASSES, diff --git a/tests/unittests/classification/test_dice.py b/tests/unittests/classification/test_dice.py index f2cb1d18205..d3737b255a2 100644 --- a/tests/unittests/classification/test_dice.py +++ b/tests/unittests/classification/test_dice.py @@ -23,23 +23,23 @@ from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_multidim as _input_mlmd -from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from unittests.classification._inputs import _input_binary, _input_binary_logits, _input_binary_prob +from unittests.classification._inputs import _input_multiclass as _input_mcls +from unittests.classification._inputs import _input_multiclass_logits as _input_mcls_logits +from unittests.classification._inputs import _input_multiclass_prob as _input_mcls_prob +from unittests.classification._inputs import _input_multiclass_with_missing_class as _input_miss_class +from unittests.classification._inputs import _input_multilabel as _input_mlb +from unittests.classification._inputs import _input_multilabel_logits as _input_mlb_logits +from unittests.classification._inputs import _input_multilabel_multidim as _input_mlmd +from unittests.classification._inputs import _input_multilabel_multidim_prob as _input_mlmd_prob +from unittests.classification._inputs import _input_multilabel_prob as _input_mlb_prob from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -def _scipy_dice( +def _reference_scipy_dice( preds: Tensor, target: Tensor, ignore_index: Optional[int] = None, @@ -102,7 +102,7 @@ def test_dice_class(self, ddp, preds, target, ignore_index): preds=preds, 
target=target, metric_class=Dice, - reference_metric=partial(_scipy_dice, ignore_index=ignore_index), + reference_metric=partial(_reference_scipy_dice, ignore_index=ignore_index), metric_args={"ignore_index": ignore_index}, ) @@ -112,7 +112,7 @@ def test_dice_fn(self, preds, target, ignore_index): preds, target, metric_functional=dice, - reference_metric=partial(_scipy_dice, ignore_index=ignore_index), + reference_metric=partial(_reference_scipy_dice, ignore_index=ignore_index), metric_args={"ignore_index": ignore_index}, ) @@ -143,7 +143,7 @@ def test_dice_class(self, ddp, preds, target, ignore_index): preds=preds, target=target, metric_class=Dice, - reference_metric=partial(_scipy_dice, ignore_index=ignore_index), + reference_metric=partial(_reference_scipy_dice, ignore_index=ignore_index), metric_args={"ignore_index": ignore_index}, ) @@ -153,6 +153,6 @@ def test_dice_fn(self, preds, target, ignore_index): preds, target, metric_functional=dice, - reference_metric=partial(_scipy_dice, ignore_index=ignore_index), + reference_metric=partial(_reference_scipy_dice, ignore_index=ignore_index), metric_args={"ignore_index": ignore_index}, ) diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index 79540a94535..048003c1699 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -22,14 +22,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index seed_all(42) -def _baseline_exact_match_multiclass(preds, target, ignore_index, multidim_average): +def _reference_exact_match_multiclass(preds, target, ignore_index, multidim_average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) preds = preds.numpy() @@ -68,7 +68,7 @@ def test_multiclass_exact_match(self, ddp, inputs, ignore_index, multidim_averag target=target, metric_class=MulticlassExactMatch, reference_metric=partial( - _baseline_exact_match_multiclass, + _reference_exact_match_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, ), @@ -94,7 +94,7 @@ def test_multiclass_exact_match_functional(self, inputs, ignore_index, multidim_ target=target, metric_functional=multiclass_exact_match, reference_metric=partial( - _baseline_exact_match_multiclass, + _reference_exact_match_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, ), @@ -147,7 +147,7 @@ def test_multiclass_exact_match_half_gpu(self, inputs, dtype): ) -def _baseline_exact_match_multilabel(preds, target, ignore_index, multidim_average): +def _reference_exact_match_multilabel(preds, target, ignore_index, multidim_average): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -195,7 +195,7 @@ def test_multilabel_exact_match(self, ddp, inputs, ignore_index, multidim_averag target=target, metric_class=MultilabelExactMatch, reference_metric=partial( - _baseline_exact_match_multilabel, + _reference_exact_match_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, ), @@ -222,7 +222,7 @@ def test_multilabel_exact_match_functional(self, inputs, ignore_index, multidim_ target=target, metric_functional=multilabel_exact_match, 
reference_metric=partial( - _baseline_exact_match_multilabel, + _reference_exact_match_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, ), diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 4534a1b9259..a6cfc5f71b8 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -42,14 +42,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, multidim_average): +def _reference_sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, multidim_average): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -106,7 +106,10 @@ def test_binary_fbeta_score(self, ddp, inputs, module, functional, compare, igno target=target, metric_class=module, reference_metric=partial( - _sklearn_fbeta_score_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average + _reference_sklearn_fbeta_score_binary, + sk_fn=compare, + ignore_index=ignore_index, + multidim_average=multidim_average, ), metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) @@ -126,7 +129,10 @@ def test_binary_fbeta_score_functional(self, inputs, module, functional, compare target=target, metric_functional=functional, reference_metric=partial( - _sklearn_fbeta_score_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average + _reference_sklearn_fbeta_score_binary, + sk_fn=compare, + ignore_index=ignore_index, + multidim_average=multidim_average, ), metric_args={ "threshold": THRESHOLD, @@ -177,7 +183,7 @@ def test_binary_fbeta_score_half_gpu(self, inputs, module, functional, compare, ) -def _sklearn_fbeta_score_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): +def _reference_sklearn_fbeta_score_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) if multidim_average == "global": @@ -236,7 +242,7 @@ def test_multiclass_fbeta_score( target=target, metric_class=module, reference_metric=partial( - _sklearn_fbeta_score_multiclass, + _reference_sklearn_fbeta_score_multiclass, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, @@ -268,7 +274,7 @@ def test_multiclass_fbeta_score_functional( target=target, metric_functional=functional, reference_metric=partial( - _sklearn_fbeta_score_multiclass, + _reference_sklearn_fbeta_score_multiclass, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, @@ -362,7 +368,7 @@ def test_top_k( assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) -def _sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average): +def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average): if average == "micro": preds = preds.flatten() target = target.flatten() @@ -390,7 +396,7 @@ def _sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, 
ignore_index, a return None -def _sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average): +def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average): fbeta_score, weights = [], [] for i in range(preds.shape[0]): if average == "micro": @@ -424,7 +430,7 @@ def _sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, av return None -def _sklearn_fbeta_score_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): +def _reference_sklearn_fbeta_score_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -440,8 +446,8 @@ def _sklearn_fbeta_score_multilabel(preds, target, sk_fn, ignore_index, multidim average=average, ) if multidim_average == "global": - return _sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average) - return _sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average) + return _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average) + return _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average) @pytest.mark.parametrize("inputs", _multilabel_cases) @@ -482,7 +488,7 @@ def test_multilabel_fbeta_score( target=target, metric_class=module, reference_metric=partial( - _sklearn_fbeta_score_multilabel, + _reference_sklearn_fbeta_score_multilabel, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, @@ -515,7 +521,7 @@ def test_multilabel_fbeta_score_functional( target=target, metric_functional=functional, reference_metric=partial( - _sklearn_fbeta_score_multilabel, + _reference_sklearn_fbeta_score_multilabel, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index d286ceae140..811e2a55ab9 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -29,7 +29,7 @@ from torchmetrics.utilities.imports import _PYTHON_LOWER_3_8 from unittests import THRESHOLD -from unittests.classification.inputs import _group_cases +from unittests.classification._inputs import _group_cases from unittests.helpers import seed_all from unittests.helpers.testers import ( MetricTester, @@ -44,7 +44,7 @@ seed_all(42) -def _fairlearn_binary(preds, target, groups, ignore_index): +def _reference_fairlearn_binary(preds, target, groups, ignore_index): metrics = {"dp": selection_rate, "eo": true_positive_rate} preds = preds.numpy() @@ -81,7 +81,7 @@ def _assert_tensor(pl_result: Dict[str, Tensor], key: Optional[str] = None) -> N _core_assert_tensor(pl_result, key) -def _assert_allclose( +def _assert_allclose( # todo: unify with the general assert_allclose pl_result: Dict[str, Tensor], sk_result: Dict[str, Tensor], atol: float = 1e-8, key: Optional[str] = None ) -> None: if isinstance(pl_result, dict) and key is None: @@ -240,7 +240,7 @@ def test_binary_fairness(self, ddp, inputs, ignore_index): preds=preds, target=target, metric_class=BinaryFairness, - reference_metric=partial(_fairlearn_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_fairlearn_binary, ignore_index=ignore_index), metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "num_groups": 2, "task": "all"}, groups=groups, fragment_kwargs=True, @@ 
-257,7 +257,7 @@ def test_binary_fairness_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=binary_fairness, - reference_metric=partial(_fairlearn_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_fairlearn_binary, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, "ignore_index": ignore_index, diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index ad6c2e199b4..8ccbbc9e1fb 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -33,19 +33,19 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_hamming_loss(target, preds): +def _reference_sklearn_hamming_loss(target, preds): score = sk_hamming_loss(target, preds) return score if not np.isnan(score) else 1.0 -def _sklearn_hamming_distance_binary(preds, target, ignore_index, multidim_average): +def _reference_sklearn_hamming_distance_binary(preds, target, ignore_index, multidim_average): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -60,14 +60,14 @@ def _sklearn_hamming_distance_binary(preds, target, ignore_index, multidim_avera if multidim_average == "global": target, preds = remove_ignore_index(target, preds, ignore_index) - return _sklearn_hamming_loss(target, preds) + return _reference_sklearn_hamming_loss(target, preds) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - res.append(_sklearn_hamming_loss(true, pred)) + res.append(_reference_sklearn_hamming_loss(true, pred)) return np.stack(res) @@ -94,7 +94,7 @@ def test_binary_hamming_distance(self, ddp, inputs, ignore_index, multidim_avera target=target, metric_class=BinaryHammingDistance, reference_metric=partial( - _sklearn_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average + _reference_sklearn_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average ), metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) @@ -114,7 +114,7 @@ def test_binary_hamming_distance_functional(self, inputs, ignore_index, multidim target=target, metric_functional=binary_hamming_distance, reference_metric=partial( - _sklearn_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average + _reference_sklearn_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average ), metric_args={ "threshold": THRESHOLD, @@ -164,12 +164,12 @@ def test_binary_hamming_distance_dtype_gpu(self, inputs, dtype): ) -def _sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, average): +def _reference_sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, average): preds = preds.numpy().flatten() target = target.numpy().flatten() target, preds = remove_ignore_index(target, preds, ignore_index) if average == "micro": - return _sklearn_hamming_loss(target, preds) + return 
_reference_sklearn_hamming_loss(target, preds) confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) hamming_per_class = 1 - confmat.diagonal() / confmat.sum(axis=1) hamming_per_class[np.isnan(hamming_per_class)] = 1.0 @@ -184,7 +184,7 @@ def _sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, ave return hamming_per_class -def _sklearn_hamming_distance_multiclass_local(preds, target, ignore_index, average): +def _reference_sklearn_hamming_distance_multiclass_local(preds, target, ignore_index, average): preds = preds.numpy() target = target.numpy() res = [] @@ -193,7 +193,7 @@ def _sklearn_hamming_distance_multiclass_local(preds, target, ignore_index, aver true = true.flatten() true, pred = remove_ignore_index(true, pred, ignore_index) if average == "micro": - res.append(_sklearn_hamming_loss(true, pred)) + res.append(_reference_sklearn_hamming_loss(true, pred)) else: confmat = sk_confusion_matrix(true, pred, labels=list(range(NUM_CLASSES))) hamming_per_class = 1 - confmat.diagonal() / confmat.sum(axis=1) @@ -212,12 +212,12 @@ def _sklearn_hamming_distance_multiclass_local(preds, target, ignore_index, aver return np.stack(res, 0) -def _sklearn_hamming_distance_multiclass(preds, target, ignore_index, multidim_average, average): +def _reference_sklearn_hamming_distance_multiclass(preds, target, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) if multidim_average == "global": - return _sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, average) - return _sklearn_hamming_distance_multiclass_local(preds, target, ignore_index, average) + return _reference_sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, average) + return _reference_sklearn_hamming_distance_multiclass_local(preds, target, ignore_index, average) @pytest.mark.parametrize("inputs", _multiclass_cases) @@ -244,7 +244,7 @@ def test_multiclass_hamming_distance(self, ddp, inputs, ignore_index, multidim_a target=target, metric_class=MulticlassHammingDistance, reference_metric=partial( - _sklearn_hamming_distance_multiclass, + _reference_sklearn_hamming_distance_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -273,7 +273,7 @@ def test_multiclass_hamming_distance_functional(self, inputs, ignore_index, mult target=target, metric_functional=multiclass_hamming_distance, reference_metric=partial( - _sklearn_hamming_distance_multiclass, + _reference_sklearn_hamming_distance_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -327,19 +327,19 @@ def test_multiclass_hamming_distance_dtype_gpu(self, inputs, dtype): ) -def _sklearn_hamming_distance_multilabel_global(preds, target, ignore_index, average): +def _reference_sklearn_hamming_distance_multilabel_global(preds, target, ignore_index, average): if average == "micro": preds = preds.flatten() target = target.flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - return _sklearn_hamming_loss(target, preds) + return _reference_sklearn_hamming_loss(target, preds) hamming, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() true, pred = remove_ignore_index(true, pred, ignore_index) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - hamming.append(_sklearn_hamming_loss(true, pred)) + hamming.append(_reference_sklearn_hamming_loss(true, pred)) weights.append(confmat[1, 1] + 
confmat[1, 0]) res = np.stack(hamming, axis=0) @@ -355,19 +355,19 @@ def _sklearn_hamming_distance_multilabel_global(preds, target, ignore_index, ave return None -def _sklearn_hamming_distance_multilabel_local(preds, target, ignore_index, average): +def _reference_sklearn_hamming_distance_multilabel_local(preds, target, ignore_index, average): hamming, weights = [], [] for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() true, pred = remove_ignore_index(true, pred, ignore_index) - hamming.append(_sklearn_hamming_loss(true, pred)) + hamming.append(_reference_sklearn_hamming_loss(true, pred)) else: scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] true, pred = remove_ignore_index(true, pred, ignore_index) - scores.append(_sklearn_hamming_loss(true, pred)) + scores.append(_reference_sklearn_hamming_loss(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) hamming.append(np.stack(scores)) @@ -387,7 +387,7 @@ def _sklearn_hamming_distance_multilabel_local(preds, target, ignore_index, aver return None -def _sklearn_hamming_distance_multilabel(preds, target, ignore_index, multidim_average, average): +def _reference_sklearn_hamming_distance_multilabel(preds, target, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -398,8 +398,8 @@ def _sklearn_hamming_distance_multilabel(preds, target, ignore_index, multidim_a target = target.reshape(*target.shape[:2], -1) if multidim_average == "global": - return _sklearn_hamming_distance_multilabel_global(preds, target, ignore_index, average) - return _sklearn_hamming_distance_multilabel_local(preds, target, ignore_index, average) + return _reference_sklearn_hamming_distance_multilabel_global(preds, target, ignore_index, average) + return _reference_sklearn_hamming_distance_multilabel_local(preds, target, ignore_index, average) @pytest.mark.parametrize("inputs", _multilabel_cases) @@ -426,7 +426,7 @@ def test_multilabel_hamming_distance(self, ddp, inputs, ignore_index, multidim_a target=target, metric_class=MultilabelHammingDistance, reference_metric=partial( - _sklearn_hamming_distance_multilabel, + _reference_sklearn_hamming_distance_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -456,7 +456,7 @@ def test_multilabel_hamming_distance_functional(self, inputs, ignore_index, mult target=target, metric_functional=multilabel_hamming_distance, reference_metric=partial( - _sklearn_hamming_distance_multilabel, + _reference_sklearn_hamming_distance_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, average=average, diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index 9c46aa59f10..6b9eaca1abd 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -25,13 +25,13 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index torch.manual_seed(42) -def _sklearn_binary_hinge_loss(preds, target, ignore_index): +def _reference_sklearn_binary_hinge_loss(preds, target, ignore_index): preds = 
preds.numpy().flatten() target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -58,7 +58,7 @@ def test_binary_hinge_loss(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=BinaryHingeLoss, - reference_metric=partial(_sklearn_binary_hinge_loss, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_binary_hinge_loss, ignore_index=ignore_index), metric_args={ "ignore_index": ignore_index, }, @@ -74,7 +74,7 @@ def test_binary_hinge_loss_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=binary_hinge_loss, - reference_metric=partial(_sklearn_binary_hinge_loss, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_binary_hinge_loss, ignore_index=ignore_index), metric_args={ "ignore_index": ignore_index, }, @@ -118,7 +118,7 @@ def test_binary_hinge_loss_dtype_gpu(self, inputs, dtype): ) -def _sklearn_multiclass_hinge_loss(preds, target, multiclass_mode, ignore_index): +def _reference_sklearn_multiclass_hinge_loss(preds, target, multiclass_mode, ignore_index): preds = preds.numpy() target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -159,7 +159,7 @@ def test_multiclass_hinge_loss(self, inputs, ddp, multiclass_mode, ignore_index) target=target, metric_class=MulticlassHingeLoss, reference_metric=partial( - _sklearn_multiclass_hinge_loss, multiclass_mode=multiclass_mode, ignore_index=ignore_index + _reference_sklearn_multiclass_hinge_loss, multiclass_mode=multiclass_mode, ignore_index=ignore_index ), metric_args={ "num_classes": NUM_CLASSES, @@ -180,7 +180,7 @@ def test_multiclass_hinge_loss_functional(self, inputs, multiclass_mode, ignore_ target=target, metric_functional=multiclass_hinge_loss, reference_metric=partial( - _sklearn_multiclass_hinge_loss, multiclass_mode=multiclass_mode, ignore_index=ignore_index + _reference_sklearn_multiclass_hinge_loss, multiclass_mode=multiclass_mode, ignore_index=ignore_index ), metric_args={ "num_classes": NUM_CLASSES, diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 4dc52617491..c1e6354d57a 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -33,11 +33,11 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index -def _sklearn_jaccard_index_binary(preds, target, ignore_index=None): +def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None): preds = preds.view(-1).numpy() target = target.view(-1).numpy() if np.issubdtype(preds.dtype, np.floating): @@ -64,7 +64,7 @@ def test_binary_jaccard_index(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=BinaryJaccardIndex, - reference_metric=partial(_sklearn_jaccard_index_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_jaccard_index_binary, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, "ignore_index": ignore_index, @@ -81,7 +81,7 @@ def test_binary_jaccard_index_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=binary_jaccard_index, - reference_metric=partial(_sklearn_jaccard_index_binary, 
ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_jaccard_index_binary, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, "ignore_index": ignore_index, @@ -129,7 +129,7 @@ def test_binary_jaccard_index_dtype_gpu(self, inputs, dtype): ) -def _sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average="macro"): +def _reference_sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average="macro"): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -163,7 +163,9 @@ def test_multiclass_jaccard_index(self, inputs, ddp, ignore_index, average): preds=preds, target=target, metric_class=MulticlassJaccardIndex, - reference_metric=partial(_sklearn_jaccard_index_multiclass, ignore_index=ignore_index, average=average), + reference_metric=partial( + _reference_sklearn_jaccard_index_multiclass, ignore_index=ignore_index, average=average + ), metric_args={ "num_classes": NUM_CLASSES, "ignore_index": ignore_index, @@ -182,7 +184,9 @@ def test_multiclass_jaccard_index_functional(self, inputs, ignore_index, average preds=preds, target=target, metric_functional=multiclass_jaccard_index, - reference_metric=partial(_sklearn_jaccard_index_multiclass, ignore_index=ignore_index, average=average), + reference_metric=partial( + _reference_sklearn_jaccard_index_multiclass, ignore_index=ignore_index, average=average + ), metric_args={ "num_classes": NUM_CLASSES, "ignore_index": ignore_index, @@ -229,7 +233,7 @@ def test_multiclass_jaccard_index_dtype_gpu(self, inputs, dtype): ) -def _sklearn_jaccard_index_multilabel(preds, target, ignore_index=None, average="macro"): +def _reference_sklearn_jaccard_index_multilabel(preds, target, ignore_index=None, average="macro"): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -242,7 +246,7 @@ def _sklearn_jaccard_index_multilabel(preds, target, ignore_index=None, average= return sk_jaccard_index(y_true=target, y_pred=preds, average=average) if average == "micro": - return _sklearn_jaccard_index_binary(torch.tensor(preds), torch.tensor(target), ignore_index) + return _reference_sklearn_jaccard_index_binary(torch.tensor(preds), torch.tensor(target), ignore_index) scores, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i], target[:, i] @@ -276,7 +280,9 @@ def test_multilabel_jaccard_index(self, inputs, ddp, ignore_index, average): preds=preds, target=target, metric_class=MultilabelJaccardIndex, - reference_metric=partial(_sklearn_jaccard_index_multilabel, ignore_index=ignore_index, average=average), + reference_metric=partial( + _reference_sklearn_jaccard_index_multilabel, ignore_index=ignore_index, average=average + ), metric_args={ "num_labels": NUM_CLASSES, "ignore_index": ignore_index, @@ -295,7 +301,9 @@ def test_multilabel_jaccard_index_functional(self, inputs, ignore_index, average preds=preds, target=target, metric_functional=multilabel_jaccard_index, - reference_metric=partial(_sklearn_jaccard_index_multilabel, ignore_index=ignore_index, average=average), + reference_metric=partial( + _reference_sklearn_jaccard_index_multilabel, ignore_index=ignore_index, average=average + ), metric_args={ "num_labels": NUM_CLASSES, "ignore_index": ignore_index, diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 99baeee09cc..f8c0801b5ad 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ 
b/tests/unittests/classification/test_matthews_corrcoef.py @@ -32,14 +32,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_matthews_corrcoef_binary(preds, target, ignore_index=None): +def _reference_sklearn_matthews_corrcoef_binary(preds, target, ignore_index=None): preds = preds.view(-1).numpy() target = target.view(-1).numpy() if np.issubdtype(preds.dtype, np.floating): @@ -66,7 +66,7 @@ def test_binary_matthews_corrcoef(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=BinaryMatthewsCorrCoef, - reference_metric=partial(_sklearn_matthews_corrcoef_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_matthews_corrcoef_binary, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, "ignore_index": ignore_index, @@ -83,7 +83,7 @@ def test_binary_matthews_corrcoef_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=binary_matthews_corrcoef, - reference_metric=partial(_sklearn_matthews_corrcoef_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_matthews_corrcoef_binary, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, "ignore_index": ignore_index, @@ -131,7 +131,7 @@ def test_binary_matthews_corrcoef_dtype_gpu(self, inputs, dtype): ) -def _sklearn_matthews_corrcoef_multiclass(preds, target, ignore_index=None): +def _reference_sklearn_matthews_corrcoef_multiclass(preds, target, ignore_index=None): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -158,7 +158,7 @@ def test_multiclass_matthews_corrcoef(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=MulticlassMatthewsCorrCoef, - reference_metric=partial(_sklearn_matthews_corrcoef_multiclass, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_matthews_corrcoef_multiclass, ignore_index=ignore_index), metric_args={ "num_classes": NUM_CLASSES, "ignore_index": ignore_index, @@ -175,7 +175,7 @@ def test_multiclass_matthews_corrcoef_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=multiclass_matthews_corrcoef, - reference_metric=partial(_sklearn_matthews_corrcoef_multiclass, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_matthews_corrcoef_multiclass, ignore_index=ignore_index), metric_args={ "num_classes": NUM_CLASSES, "ignore_index": ignore_index, @@ -221,7 +221,7 @@ def test_multiclass_matthews_corrcoef_dtype_gpu(self, inputs, dtype): ) -def _sklearn_matthews_corrcoef_multilabel(preds, target, ignore_index=None): +def _reference_sklearn_matthews_corrcoef_multilabel(preds, target, ignore_index=None): preds = preds.view(-1).numpy() target = target.view(-1).numpy() if np.issubdtype(preds.dtype, np.floating): @@ -248,7 +248,7 @@ def test_multilabel_matthews_corrcoef(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=MultilabelMatthewsCorrCoef, - reference_metric=partial(_sklearn_matthews_corrcoef_multilabel, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_matthews_corrcoef_multilabel, ignore_index=ignore_index), 
metric_args={ "num_labels": NUM_CLASSES, "ignore_index": ignore_index, @@ -265,7 +265,7 @@ def test_multilabel_matthews_corrcoef_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=multilabel_matthews_corrcoef, - reference_metric=partial(_sklearn_matthews_corrcoef_multilabel, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_matthews_corrcoef_multilabel, ignore_index=ignore_index), metric_args={ "num_labels": NUM_CLASSES, "ignore_index": ignore_index, diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index 0d02d4e488e..201a6c80ce7 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py @@ -34,7 +34,7 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -53,7 +53,7 @@ def _precision_at_recall_x_multilabel(predictions, targets, min_recall): return float(max_precision), float(best_threshold) -def _sklearn_precision_at_fixed_recall_binary(preds, target, min_recall, ignore_index=None): +def _reference_sklearn_precision_at_fixed_recall_binary(preds, target, min_recall, ignore_index=None): preds = preds.flatten().numpy() target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): @@ -80,7 +80,7 @@ def test_binary_precision_at_fixed_recall(self, inputs, ddp, min_recall, ignore_ target=target, metric_class=BinaryPrecisionAtFixedRecall, reference_metric=partial( - _sklearn_precision_at_fixed_recall_binary, min_recall=min_recall, ignore_index=ignore_index + _reference_sklearn_precision_at_fixed_recall_binary, min_recall=min_recall, ignore_index=ignore_index ), metric_args={ "min_recall": min_recall, @@ -101,7 +101,7 @@ def test_binary_precision_at_fixed_recall_functional(self, inputs, min_recall, i target=target, metric_functional=binary_precision_at_fixed_recall, reference_metric=partial( - _sklearn_precision_at_fixed_recall_binary, min_recall=min_recall, ignore_index=ignore_index + _reference_sklearn_precision_at_fixed_recall_binary, min_recall=min_recall, ignore_index=ignore_index ), metric_args={ "min_recall": min_recall, @@ -164,7 +164,7 @@ def test_binary_precision_at_fixed_recall_threshold_arg(self, inputs, min_recall assert torch.allclose(r1, r2) -def _sklearn_precision_at_fixed_recall_multiclass(preds, target, min_recall, ignore_index=None): +def _reference_sklearn_precision_at_fixed_recall_multiclass(preds, target, min_recall, ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -201,7 +201,9 @@ def test_multiclass_precision_at_fixed_recall(self, inputs, ddp, min_recall, ign target=target, metric_class=MulticlassPrecisionAtFixedRecall, reference_metric=partial( - _sklearn_precision_at_fixed_recall_multiclass, min_recall=min_recall, ignore_index=ignore_index + _reference_sklearn_precision_at_fixed_recall_multiclass, + min_recall=min_recall, + ignore_index=ignore_index, ), metric_args={ "min_recall": min_recall, @@ -223,7 +225,9 @@ def 
test_multiclass_precision_at_fixed_recall_functional(self, inputs, min_recal target=target, metric_functional=multiclass_precision_at_fixed_recall, reference_metric=partial( - _sklearn_precision_at_fixed_recall_multiclass, min_recall=min_recall, ignore_index=ignore_index + _reference_sklearn_precision_at_fixed_recall_multiclass, + min_recall=min_recall, + ignore_index=ignore_index, ), metric_args={ "min_recall": min_recall, @@ -290,10 +294,10 @@ def test_multiclass_precision_at_fixed_recall_threshold_arg(self, inputs, min_re assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) -def _sklearn_precision_at_fixed_recall_multilabel(preds, target, min_recall, ignore_index=None): +def _reference_sklearn_precision_at_fixed_recall_multilabel(preds, target, min_recall, ignore_index=None): precision, thresholds = [], [] for i in range(NUM_CLASSES): - res = _sklearn_precision_at_fixed_recall_binary(preds[:, i], target[:, i], min_recall, ignore_index) + res = _reference_sklearn_precision_at_fixed_recall_binary(preds[:, i], target[:, i], min_recall, ignore_index) precision.append(res[0]) thresholds.append(res[1]) return precision, thresholds @@ -319,7 +323,9 @@ def test_multilabel_precision_at_fixed_recall(self, inputs, ddp, min_recall, ign target=target, metric_class=MultilabelPrecisionAtFixedRecall, reference_metric=partial( - _sklearn_precision_at_fixed_recall_multilabel, min_recall=min_recall, ignore_index=ignore_index + _reference_sklearn_precision_at_fixed_recall_multilabel, + min_recall=min_recall, + ignore_index=ignore_index, ), metric_args={ "min_recall": min_recall, @@ -341,7 +347,9 @@ def test_multilabel_precision_at_fixed_recall_functional(self, inputs, min_recal target=target, metric_functional=multilabel_precision_at_fixed_recall, reference_metric=partial( - _sklearn_precision_at_fixed_recall_multilabel, min_recall=min_recall, ignore_index=ignore_index + _reference_sklearn_precision_at_fixed_recall_multilabel, + min_recall=min_recall, + ignore_index=ignore_index, ), metric_args={ "min_recall": min_recall, diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index f438dd50e21..d1c588de3b5 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -42,14 +42,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_precision_recall_binary(preds, target, sk_fn, ignore_index, multidim_average): +def _reference_sklearn_precision_recall_binary(preds, target, sk_fn, ignore_index, multidim_average): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -106,7 +106,7 @@ def test_binary_precision_recall(self, ddp, inputs, module, functional, compare, target=target, metric_class=module, reference_metric=partial( - _sklearn_precision_recall_binary, + _reference_sklearn_precision_recall_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, @@ -131,7 +131,7 @@ def test_binary_precision_recall_functional( target=target, metric_functional=functional, reference_metric=partial( - 
_sklearn_precision_recall_binary, + _reference_sklearn_precision_recall_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, @@ -184,7 +184,7 @@ def test_binary_precision_recall_half_gpu(self, inputs, module, functional, comp ) -def _sklearn_precision_recall_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): +def _reference_sklearn_precision_recall_multiclass(preds, target, sk_fn, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) @@ -241,7 +241,7 @@ def test_multiclass_precision_recall( target=target, metric_class=module, reference_metric=partial( - _sklearn_precision_recall_multiclass, + _reference_sklearn_precision_recall_multiclass, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, @@ -273,7 +273,7 @@ def test_multiclass_precision_recall_functional( target=target, metric_functional=functional, reference_metric=partial( - _sklearn_precision_recall_multiclass, + _reference_sklearn_precision_recall_multiclass, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, @@ -362,7 +362,7 @@ def test_top_k( assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) -def _sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average): +def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average): if average == "micro": preds = preds.flatten() target = target.flatten() @@ -390,7 +390,7 @@ def _sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_ind return None -def _sklearn_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average): +def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average): precision_recall, weights = [], [] for i in range(preds.shape[0]): if average == "micro": @@ -424,7 +424,7 @@ def _sklearn_precision_recall_multilabel_local(preds, target, sk_fn, ignore_inde return None -def _sklearn_precision_recall_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): +def _reference_sklearn_precision_recall_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -440,8 +440,8 @@ def _sklearn_precision_recall_multilabel(preds, target, sk_fn, ignore_index, mul average=average, ) if multidim_average == "global": - return _sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average) - return _sklearn_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average) + return _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average) + return _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average) @pytest.mark.parametrize("inputs", _multilabel_cases) @@ -478,7 +478,7 @@ def test_multilabel_precision_recall( target=target, metric_class=module, reference_metric=partial( - _sklearn_precision_recall_multilabel, + _reference_sklearn_precision_recall_multilabel, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average, @@ -511,7 +511,7 @@ def test_multilabel_precision_recall_functional( target=target, metric_functional=functional, reference_metric=partial( - _sklearn_precision_recall_multilabel, + _reference_sklearn_precision_recall_multilabel, sk_fn=compare, ignore_index=ignore_index, 
multidim_average=multidim_average, diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 9c1d4263b99..2dccb2e4176 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -33,14 +33,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_precision_recall_curve_binary(preds, target, ignore_index=None): +def _reference_sklearn_precision_recall_curve_binary(preds, target, ignore_index=None): preds = preds.flatten().numpy() target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): @@ -65,7 +65,7 @@ def test_binary_precision_recall_curve(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=BinaryPrecisionRecallCurve, - reference_metric=partial(_sklearn_precision_recall_curve_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_precision_recall_curve_binary, ignore_index=ignore_index), metric_args={ "thresholds": None, "ignore_index": ignore_index, @@ -82,7 +82,7 @@ def test_binary_precision_recall_curve_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=binary_precision_recall_curve, - reference_metric=partial(_sklearn_precision_recall_curve_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_precision_recall_curve_binary, ignore_index=ignore_index), metric_args={ "thresholds": None, "ignore_index": ignore_index, @@ -153,7 +153,7 @@ def test_binary_error_on_wrong_dtypes(self, inputs): binary_precision_recall_curve(preds[0].long(), target[0]) -def _sklearn_precision_recall_curve_multiclass(preds, target, ignore_index=None): +def _reference_sklearn_precision_recall_curve_multiclass(preds, target, ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -189,7 +189,7 @@ def test_multiclass_precision_recall_curve(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=MulticlassPrecisionRecallCurve, - reference_metric=partial(_sklearn_precision_recall_curve_multiclass, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_precision_recall_curve_multiclass, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_classes": NUM_CLASSES, @@ -207,7 +207,7 @@ def test_multiclass_precision_recall_curve_functional(self, inputs, ignore_index preds=preds, target=target, metric_functional=multiclass_precision_recall_curve, - reference_metric=partial(_sklearn_precision_recall_curve_multiclass, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_precision_recall_curve_multiclass, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_classes": NUM_CLASSES, @@ -300,10 +300,10 @@ def test_multiclass_average(self, inputs, average, thresholds): ) -def _sklearn_precision_recall_curve_multilabel(preds, target, ignore_index=None): +def _reference_sklearn_precision_recall_curve_multilabel(preds, target, 
ignore_index=None): precision, recall, thresholds = [], [], [] for i in range(NUM_CLASSES): - res = _sklearn_precision_recall_curve_binary(preds[:, i], target[:, i], ignore_index) + res = _reference_sklearn_precision_recall_curve_binary(preds[:, i], target[:, i], ignore_index) precision.append(res[0]) recall.append(res[1]) thresholds.append(res[2]) @@ -328,7 +328,7 @@ def test_multilabel_precision_recall_curve(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=MultilabelPrecisionRecallCurve, - reference_metric=partial(_sklearn_precision_recall_curve_multilabel, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_precision_recall_curve_multilabel, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_labels": NUM_CLASSES, @@ -346,7 +346,7 @@ def test_multilabel_precision_recall_curve_functional(self, inputs, ignore_index preds=preds, target=target, metric_functional=multilabel_precision_recall_curve, - reference_metric=partial(_sklearn_precision_recall_curve_multilabel, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_precision_recall_curve_multilabel, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_labels": NUM_CLASSES, diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index 9801467dea0..f322ce44442 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -32,14 +32,14 @@ ) from unittests import NUM_CLASSES -from unittests.classification.inputs import _multilabel_cases +from unittests.classification._inputs import _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index seed_all(42) -def _sklearn_ranking(preds, target, fn, ignore_index): +def _reference_sklearn_ranking(preds, target, fn, ignore_index): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): @@ -78,7 +78,7 @@ def test_multilabel_ranking(self, inputs, metric, functional_metric, ref_metric, preds=preds, target=target, metric_class=metric, - reference_metric=partial(_sklearn_ranking, fn=ref_metric, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_ranking, fn=ref_metric, ignore_index=ignore_index), metric_args={ "num_labels": NUM_CLASSES, "ignore_index": ignore_index, @@ -95,7 +95,7 @@ def test_multilabel_ranking_functional(self, inputs, metric, functional_metric, preds=preds, target=target, metric_functional=functional_metric, - reference_metric=partial(_sklearn_ranking, fn=ref_metric, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_ranking, fn=ref_metric, ignore_index=ignore_index), metric_args={ "num_labels": NUM_CLASSES, "ignore_index": ignore_index, diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index 7eee9282073..740fafe965e 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -34,7 +34,7 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers 
import MetricTester, inject_ignore_index, remove_ignore_index @@ -53,7 +53,7 @@ def _recall_at_precision_x_multilabel(predictions, targets, min_precision): return float(max_recall), float(best_threshold) -def _sklearn_recall_at_fixed_precision_binary(preds, target, min_precision, ignore_index=None): +def _reference_sklearn_recall_at_fixed_precision_binary(preds, target, min_precision, ignore_index=None): preds = preds.flatten().numpy() target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): @@ -80,7 +80,9 @@ def test_binary_recall_at_fixed_precision(self, inputs, ddp, min_precision, igno target=target, metric_class=BinaryRecallAtFixedPrecision, reference_metric=partial( - _sklearn_recall_at_fixed_precision_binary, min_precision=min_precision, ignore_index=ignore_index + _reference_sklearn_recall_at_fixed_precision_binary, + min_precision=min_precision, + ignore_index=ignore_index, ), metric_args={ "min_precision": min_precision, @@ -101,7 +103,9 @@ def test_binary_recall_at_fixed_precision_functional(self, inputs, min_precision target=target, metric_functional=binary_recall_at_fixed_precision, reference_metric=partial( - _sklearn_recall_at_fixed_precision_binary, min_precision=min_precision, ignore_index=ignore_index + _reference_sklearn_recall_at_fixed_precision_binary, + min_precision=min_precision, + ignore_index=ignore_index, ), metric_args={ "min_precision": min_precision, @@ -164,7 +168,7 @@ def test_binary_recall_at_fixed_precision_threshold_arg(self, inputs, min_precis assert torch.allclose(r1, r2) -def _sklearn_recall_at_fixed_precision_multiclass(preds, target, min_precision, ignore_index=None): +def _reference_sklearn_recall_at_fixed_precision_multiclass(preds, target, min_precision, ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -201,7 +205,9 @@ def test_multiclass_recall_at_fixed_precision(self, inputs, ddp, min_precision, target=target, metric_class=MulticlassRecallAtFixedPrecision, reference_metric=partial( - _sklearn_recall_at_fixed_precision_multiclass, min_precision=min_precision, ignore_index=ignore_index + _reference_sklearn_recall_at_fixed_precision_multiclass, + min_precision=min_precision, + ignore_index=ignore_index, ), metric_args={ "min_precision": min_precision, @@ -223,7 +229,9 @@ def test_multiclass_recall_at_fixed_precision_functional(self, inputs, min_preci target=target, metric_functional=multiclass_recall_at_fixed_precision, reference_metric=partial( - _sklearn_recall_at_fixed_precision_multiclass, min_precision=min_precision, ignore_index=ignore_index + _reference_sklearn_recall_at_fixed_precision_multiclass, + min_precision=min_precision, + ignore_index=ignore_index, ), metric_args={ "min_precision": min_precision, @@ -290,10 +298,12 @@ def test_multiclass_recall_at_fixed_precision_threshold_arg(self, inputs, min_pr assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) -def _sklearn_recall_at_fixed_precision_multilabel(preds, target, min_precision, ignore_index=None): +def _reference_sklearn_recall_at_fixed_precision_multilabel(preds, target, min_precision, ignore_index=None): recall, thresholds = [], [] for i in range(NUM_CLASSES): - res = _sklearn_recall_at_fixed_precision_binary(preds[:, i], target[:, i], min_precision, ignore_index) + res = _reference_sklearn_recall_at_fixed_precision_binary( + preds[:, i], target[:, i], min_precision, ignore_index + ) 
recall.append(res[0]) thresholds.append(res[1]) return recall, thresholds @@ -319,7 +329,9 @@ def test_multilabel_recall_at_fixed_precision(self, inputs, ddp, min_precision, target=target, metric_class=MultilabelRecallAtFixedPrecision, reference_metric=partial( - _sklearn_recall_at_fixed_precision_multilabel, min_precision=min_precision, ignore_index=ignore_index + _reference_sklearn_recall_at_fixed_precision_multilabel, + min_precision=min_precision, + ignore_index=ignore_index, ), metric_args={ "min_precision": min_precision, @@ -341,7 +353,9 @@ def test_multilabel_recall_at_fixed_precision_functional(self, inputs, min_preci target=target, metric_functional=multilabel_recall_at_fixed_precision, reference_metric=partial( - _sklearn_recall_at_fixed_precision_multilabel, min_precision=min_precision, ignore_index=ignore_index + _reference_sklearn_recall_at_fixed_precision_multilabel, + min_precision=min_precision, + ignore_index=ignore_index, ), metric_args={ "min_precision": min_precision, diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index b69dfc0c74b..d63219052b1 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -24,14 +24,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_roc_binary(preds, target, ignore_index=None): +def _reference_sklearn_roc_binary(preds, target, ignore_index=None): preds = preds.flatten().numpy() target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): @@ -58,7 +58,7 @@ def test_binary_roc(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=BinaryROC, - reference_metric=partial(_sklearn_roc_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_roc_binary, ignore_index=ignore_index), metric_args={ "thresholds": None, "ignore_index": ignore_index, @@ -75,7 +75,7 @@ def test_binary_roc_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=binary_roc, - reference_metric=partial(_sklearn_roc_binary, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_roc_binary, ignore_index=ignore_index), metric_args={ "thresholds": None, "ignore_index": ignore_index, @@ -134,7 +134,7 @@ def test_binary_roc_threshold_arg(self, inputs, threshold_fn): assert torch.allclose(t1, t2) -def _sklearn_roc_multiclass(preds, target, ignore_index=None): +def _reference_sklearn_roc_multiclass(preds, target, ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -172,7 +172,7 @@ def test_multiclass_roc(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=MulticlassROC, - reference_metric=partial(_sklearn_roc_multiclass, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_roc_multiclass, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_classes": NUM_CLASSES, @@ -190,7 +190,7 @@ def test_multiclass_roc_functional(self, inputs, ignore_index): preds=preds, target=target, 
metric_functional=multiclass_roc, - reference_metric=partial(_sklearn_roc_multiclass, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_roc_multiclass, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_classes": NUM_CLASSES, @@ -267,10 +267,10 @@ def test_multiclass_average(self, inputs, average, thresholds): ) -def _sklearn_roc_multilabel(preds, target, ignore_index=None): +def _reference_sklearn_roc_multilabel(preds, target, ignore_index=None): fpr, tpr, thresholds = [], [], [] for i in range(NUM_CLASSES): - res = _sklearn_roc_binary(preds[:, i], target[:, i], ignore_index) + res = _reference_sklearn_roc_binary(preds[:, i], target[:, i], ignore_index) fpr.append(res[0]) tpr.append(res[1]) thresholds.append(res[2]) @@ -295,7 +295,7 @@ def test_multilabel_roc(self, inputs, ddp, ignore_index): preds=preds, target=target, metric_class=MultilabelROC, - reference_metric=partial(_sklearn_roc_multilabel, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_roc_multilabel, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_labels": NUM_CLASSES, @@ -313,7 +313,7 @@ def test_multilabel_roc_functional(self, inputs, ignore_index): preds=preds, target=target, metric_functional=multilabel_roc, - reference_metric=partial(_sklearn_roc_multilabel, ignore_index=ignore_index), + reference_metric=partial(_reference_sklearn_roc_multilabel, ignore_index=ignore_index), metric_args={ "thresholds": None, "num_labels": NUM_CLASSES, diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 18ab93ff2fc..b64a636b57f 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -36,7 +36,7 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -74,7 +74,7 @@ def _sensitivity_at_specificity_x_multilabel(predictions, targets, min_specifici return float(max_spec), float(best_threshold) -def _sklearn_sensitivity_at_specificity_binary(preds, target, min_specificity, ignore_index=None): +def _reference_sklearn_sensitivity_at_specificity_binary(preds, target, min_specificity, ignore_index=None): preds = preds.flatten().numpy() target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): @@ -103,7 +103,9 @@ def test_binary_sensitivity_at_specificity(self, inputs, ddp, min_specificity, i target=target, metric_class=BinarySensitivityAtSpecificity, reference_metric=partial( - _sklearn_sensitivity_at_specificity_binary, min_specificity=min_specificity, ignore_index=ignore_index + _reference_sklearn_sensitivity_at_specificity_binary, + min_specificity=min_specificity, + ignore_index=ignore_index, ), metric_args={ "min_specificity": min_specificity, @@ -125,7 +127,9 @@ def test_binary_sensitivity_at_specificity_functional(self, inputs, min_specific target=target, metric_functional=binary_sensitivity_at_specificity, reference_metric=partial( - _sklearn_sensitivity_at_specificity_binary, min_specificity=min_specificity, ignore_index=ignore_index + 
_reference_sklearn_sensitivity_at_specificity_binary, + min_specificity=min_specificity, + ignore_index=ignore_index, ), metric_args={ "min_specificity": min_specificity, @@ -188,7 +192,7 @@ def test_binary_sensitivity_at_specificity_threshold_arg(self, inputs, min_speci assert torch.allclose(r1, r2) -def _sklearn_sensitivity_at_specificity_multiclass(preds, target, min_specificity, ignore_index=None): +def _reference_sklearn_sensitivity_at_specificity_multiclass(preds, target, min_specificity, ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -227,7 +231,7 @@ def test_multiclass_sensitivity_at_specificity(self, inputs, ddp, min_specificit target=target, metric_class=MulticlassSensitivityAtSpecificity, reference_metric=partial( - _sklearn_sensitivity_at_specificity_multiclass, + _reference_sklearn_sensitivity_at_specificity_multiclass, min_specificity=min_specificity, ignore_index=ignore_index, ), @@ -252,7 +256,7 @@ def test_multiclass_sensitivity_at_specificity_functional(self, inputs, min_spec target=target, metric_functional=multiclass_sensitivity_at_specificity, reference_metric=partial( - _sklearn_sensitivity_at_specificity_multiclass, + _reference_sklearn_sensitivity_at_specificity_multiclass, min_specificity=min_specificity, ignore_index=ignore_index, ), @@ -325,10 +329,12 @@ def test_multiclass_sensitivity_at_specificity_threshold_arg(self, inputs, min_s assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) -def _sklearn_sensitivity_at_specificity_multilabel(preds, target, min_specificity, ignore_index=None): +def _reference_sklearn_sensitivity_at_specificity_multilabel(preds, target, min_specificity, ignore_index=None): sensitivity, thresholds = [], [] for i in range(NUM_CLASSES): - res = _sklearn_sensitivity_at_specificity_binary(preds[:, i], target[:, i], min_specificity, ignore_index) + res = _reference_sklearn_sensitivity_at_specificity_binary( + preds[:, i], target[:, i], min_specificity, ignore_index + ) sensitivity.append(res[0]) thresholds.append(res[1]) return sensitivity, thresholds @@ -356,7 +362,7 @@ def test_multilabel_sensitivity_at_specificity(self, inputs, ddp, min_specificit target=target, metric_class=MultilabelSensitivityAtSpecificity, reference_metric=partial( - _sklearn_sensitivity_at_specificity_multilabel, + _reference_sklearn_sensitivity_at_specificity_multilabel, min_specificity=min_specificity, ignore_index=ignore_index, ), @@ -381,7 +387,7 @@ def test_multilabel_sensitivity_at_specificity_functional(self, inputs, min_spec target=target, metric_functional=multilabel_sensitivity_at_specificity, reference_metric=partial( - _sklearn_sensitivity_at_specificity_multilabel, + _reference_sklearn_sensitivity_at_specificity_multilabel, min_specificity=min_specificity, ignore_index=ignore_index, ), diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 824e8667e92..3aa2dcf6cba 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -33,7 +33,7 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import 
MetricTester, inject_ignore_index @@ -50,7 +50,7 @@ def _calc_specificity(tn, fp): return tn / denom -def _baseline_specificity_binary(preds, target, ignore_index, multidim_average): +def _reference_specificity_binary(preds, target, ignore_index, multidim_average): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -107,7 +107,7 @@ def test_binary_specificity(self, ddp, inputs, ignore_index, multidim_average): target=target, metric_class=BinarySpecificity, reference_metric=partial( - _baseline_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average + _reference_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average ), metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) @@ -127,7 +127,7 @@ def test_binary_specificity_functional(self, inputs, ignore_index, multidim_aver target=target, metric_functional=binary_specificity, reference_metric=partial( - _baseline_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average + _reference_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average ), metric_args={ "threshold": THRESHOLD, @@ -177,7 +177,7 @@ def test_binary_specificity_dtype_gpu(self, inputs, dtype): ) -def _baseline_specificity_multiclass_global(preds, target, ignore_index, average): +def _reference_specificity_multiclass_global(preds, target, ignore_index, average): preds = preds.numpy().flatten() target = target.numpy().flatten() @@ -206,7 +206,7 @@ def _baseline_specificity_multiclass_global(preds, target, ignore_index, average return None -def _baseline_specificity_multiclass_local(preds, target, ignore_index, average): +def _reference_specificity_multiclass_local(preds, target, ignore_index, average): preds = preds.numpy() target = target.numpy() @@ -239,12 +239,12 @@ def _baseline_specificity_multiclass_local(preds, target, ignore_index, average) return np.stack(res, 0) -def _baseline_specificity_multiclass(preds, target, ignore_index, multidim_average, average): +def _reference_specificity_multiclass(preds, target, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) if multidim_average == "global": - return _baseline_specificity_multiclass_global(preds, target, ignore_index, average) - return _baseline_specificity_multiclass_local(preds, target, ignore_index, average) + return _reference_specificity_multiclass_global(preds, target, ignore_index, average) + return _reference_specificity_multiclass_local(preds, target, ignore_index, average) @pytest.mark.parametrize("inputs", _multiclass_cases) @@ -271,7 +271,7 @@ def test_multiclass_specificity(self, ddp, inputs, ignore_index, multidim_averag target=target, metric_class=MulticlassSpecificity, reference_metric=partial( - _baseline_specificity_multiclass, + _reference_specificity_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -300,7 +300,7 @@ def test_multiclass_specificity_functional(self, inputs, ignore_index, multidim_ target=target, metric_functional=multiclass_specificity, reference_metric=partial( - _baseline_specificity_multiclass, + _reference_specificity_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -374,7 +374,7 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spe assert torch.equal(multiclass_specificity(preds, target, top_k=k, average=average, 
num_classes=3), expected_spec) -def _baseline_specificity_multilabel_global(preds, target, ignore_index, average): +def _reference_specificity_multilabel_global(preds, target, ignore_index, average): tns, fps = [], [] for i in range(preds.shape[1]): p, t = preds[:, i].flatten(), target[:, i].flatten() @@ -402,7 +402,7 @@ def _baseline_specificity_multilabel_global(preds, target, ignore_index, average return None -def _baseline_specificity_multilabel_local(preds, target, ignore_index, average): +def _reference_specificity_multilabel_local(preds, target, ignore_index, average): specificity = [] for i in range(preds.shape[0]): tns, fps = [], [] @@ -435,7 +435,7 @@ def _baseline_specificity_multilabel_local(preds, target, ignore_index, average) return None -def _baseline_specificity_multilabel(preds, target, ignore_index, multidim_average, average): +def _reference_specificity_multilabel(preds, target, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -445,8 +445,8 @@ def _baseline_specificity_multilabel(preds, target, ignore_index, multidim_avera preds = preds.reshape(*preds.shape[:2], -1) target = target.reshape(*target.shape[:2], -1) if multidim_average == "global": - return _baseline_specificity_multilabel_global(preds, target, ignore_index, average) - return _baseline_specificity_multilabel_local(preds, target, ignore_index, average) + return _reference_specificity_multilabel_global(preds, target, ignore_index, average) + return _reference_specificity_multilabel_local(preds, target, ignore_index, average) @pytest.mark.parametrize("inputs", _multilabel_cases) @@ -473,7 +473,7 @@ def test_multilabel_specificity(self, ddp, inputs, ignore_index, multidim_averag target=target, metric_class=MultilabelSpecificity, reference_metric=partial( - _baseline_specificity_multilabel, + _reference_specificity_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -503,7 +503,7 @@ def test_multilabel_specificity_functional(self, inputs, ignore_index, multidim_ target=target, metric_functional=multilabel_specificity, reference_metric=partial( - _baseline_specificity_multilabel, + _reference_specificity_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, average=average, diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index 196f92367f7..674319787e0 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -35,7 +35,7 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -72,7 +72,7 @@ def _specificity_at_sensitivity_x_multilabel(predictions, targets, min_sensitivi return float(max_spec), float(best_threshold) -def _sklearn_specificity_at_sensitivity_binary(preds, target, min_sensitivity, ignore_index=None): +def _reference_sklearn_specificity_at_sensitivity_binary(preds, target, min_sensitivity, ignore_index=None): preds = preds.flatten().numpy() target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 
0) & (preds < 1)).all(): @@ -99,7 +99,9 @@ def test_binary_specificity_at_sensitivity(self, inputs, ddp, min_sensitivity, i target=target, metric_class=BinarySpecificityAtSensitivity, reference_metric=partial( - _sklearn_specificity_at_sensitivity_binary, min_sensitivity=min_sensitivity, ignore_index=ignore_index + _reference_sklearn_specificity_at_sensitivity_binary, + min_sensitivity=min_sensitivity, + ignore_index=ignore_index, ), metric_args={ "min_sensitivity": min_sensitivity, @@ -120,7 +122,9 @@ def test_binary_specificity_at_sensitivity_functional(self, inputs, min_sensitiv target=target, metric_functional=binary_specificity_at_sensitivity, reference_metric=partial( - _sklearn_specificity_at_sensitivity_binary, min_sensitivity=min_sensitivity, ignore_index=ignore_index + _reference_sklearn_specificity_at_sensitivity_binary, + min_sensitivity=min_sensitivity, + ignore_index=ignore_index, ), metric_args={ "min_sensitivity": min_sensitivity, @@ -183,7 +187,7 @@ def test_binary_specificity_at_sensitivity_threshold_arg(self, inputs, min_sensi assert torch.allclose(r1, r2) -def _sklearn_specificity_at_sensitivity_multiclass(preds, target, min_sensitivity, ignore_index=None): +def _reference_sklearn_specificity_at_sensitivity_multiclass(preds, target, min_sensitivity, ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): @@ -220,7 +224,7 @@ def test_multiclass_specificity_at_sensitivity(self, inputs, ddp, min_sensitivit target=target, metric_class=MulticlassSpecificityAtSensitivity, reference_metric=partial( - _sklearn_specificity_at_sensitivity_multiclass, + _reference_sklearn_specificity_at_sensitivity_multiclass, min_sensitivity=min_sensitivity, ignore_index=ignore_index, ), @@ -244,7 +248,7 @@ def test_multiclass_specificity_at_sensitivity_functional(self, inputs, min_sens target=target, metric_functional=multiclass_specificity_at_sensitivity, reference_metric=partial( - _sklearn_specificity_at_sensitivity_multiclass, + _reference_sklearn_specificity_at_sensitivity_multiclass, min_sensitivity=min_sensitivity, ignore_index=ignore_index, ), @@ -317,10 +321,12 @@ def test_multiclass_specificity_at_sensitivity_threshold_arg(self, inputs, min_s assert all(torch.allclose(r1[i], r2[i]) for i in range(len(r1))) -def _sklearn_specificity_at_sensitivity_multilabel(preds, target, min_sensitivity, ignore_index=None): +def _reference_sklearn_specificity_at_sensitivity_multilabel(preds, target, min_sensitivity, ignore_index=None): specificity, thresholds = [], [] for i in range(NUM_CLASSES): - res = _sklearn_specificity_at_sensitivity_binary(preds[:, i], target[:, i], min_sensitivity, ignore_index) + res = _reference_sklearn_specificity_at_sensitivity_binary( + preds[:, i], target[:, i], min_sensitivity, ignore_index + ) specificity.append(res[0]) thresholds.append(res[1]) return specificity, thresholds @@ -346,7 +352,7 @@ def test_multilabel_specificity_at_sensitivity(self, inputs, ddp, min_sensitivit target=target, metric_class=MultilabelSpecificityAtSensitivity, reference_metric=partial( - _sklearn_specificity_at_sensitivity_multilabel, + _reference_sklearn_specificity_at_sensitivity_multilabel, min_sensitivity=min_sensitivity, ignore_index=ignore_index, ), @@ -370,7 +376,7 @@ def test_multilabel_specificity_at_sensitivity_functional(self, inputs, min_sens target=target, metric_functional=multilabel_specificity_at_sensitivity, reference_metric=partial( - 
_sklearn_specificity_at_sensitivity_multilabel, + _reference_sklearn_specificity_at_sensitivity_multilabel, min_sensitivity=min_sensitivity, ignore_index=ignore_index, ), diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index ef6f25bd7bf..b1e4d36e1ed 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -32,14 +32,14 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sklearn_stat_scores_binary(preds, target, ignore_index, multidim_average): +def _reference_sklearn_stat_scores_binary(preds, target, ignore_index, multidim_average): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -90,7 +90,7 @@ def test_binary_stat_scores(self, ddp, inputs, ignore_index, multidim_average): target=target, metric_class=BinaryStatScores, reference_metric=partial( - _sklearn_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average + _reference_sklearn_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average ), metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) @@ -110,7 +110,7 @@ def test_binary_stat_scores_functional(self, inputs, ignore_index, multidim_aver target=target, metric_functional=binary_stat_scores, reference_metric=partial( - _sklearn_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average + _reference_sklearn_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average ), metric_args={ "threshold": THRESHOLD, @@ -160,7 +160,7 @@ def test_binary_stat_scores_dtype_gpu(self, inputs, dtype): ) -def _sklearn_stat_scores_multiclass_global(preds, target, ignore_index, average): +def _reference_sklearn_stat_scores_multiclass_global(preds, target, ignore_index, average): preds = preds.numpy().flatten() target = target.numpy().flatten() target, preds = remove_ignore_index(target, preds, ignore_index) @@ -183,7 +183,7 @@ def _sklearn_stat_scores_multiclass_global(preds, target, ignore_index, average) return None -def _sklearn_stat_scores_multiclass_local(preds, target, ignore_index, average): +def _reference_sklearn_stat_scores_multiclass_local(preds, target, ignore_index, average): preds = preds.numpy() target = target.numpy() @@ -210,12 +210,12 @@ def _sklearn_stat_scores_multiclass_local(preds, target, ignore_index, average): return np.stack(res, 0) -def _sklearn_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): +def _reference_sklearn_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) if multidim_average == "global": - return _sklearn_stat_scores_multiclass_global(preds, target, ignore_index, average) - return _sklearn_stat_scores_multiclass_local(preds, target, ignore_index, average) + return _reference_sklearn_stat_scores_multiclass_global(preds, target, ignore_index, average) + return _reference_sklearn_stat_scores_multiclass_local(preds, target, ignore_index, average) 
@pytest.mark.parametrize("inputs", _multiclass_cases) @@ -242,7 +242,7 @@ def test_multiclass_stat_scores(self, ddp, inputs, ignore_index, multidim_averag target=target, metric_class=MulticlassStatScores, reference_metric=partial( - _sklearn_stat_scores_multiclass, + _reference_sklearn_stat_scores_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -271,7 +271,7 @@ def test_multiclass_stat_scores_functional(self, inputs, ignore_index, multidim_ target=target, metric_functional=multiclass_stat_scores, reference_metric=partial( - _sklearn_stat_scores_multiclass, + _reference_sklearn_stat_scores_multiclass, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -382,7 +382,7 @@ def test_multiclass_overflow(): assert torch.allclose(res, torch.tensor(compare)) -def _sklearn_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): +def _reference_sklearn_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() if np.issubdtype(preds.dtype, np.floating): @@ -457,7 +457,7 @@ def test_multilabel_stat_scores(self, ddp, inputs, ignore_index, multidim_averag target=target, metric_class=MultilabelStatScores, reference_metric=partial( - _sklearn_stat_scores_multilabel, + _reference_sklearn_stat_scores_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, average=average, @@ -487,7 +487,7 @@ def test_multilabel_stat_scores_functional(self, inputs, ignore_index, multidim_ target=target, metric_functional=multilabel_stat_scores, reference_metric=partial( - _sklearn_stat_scores_multilabel, + _reference_sklearn_stat_scores_multilabel, ignore_index=ignore_index, multidim_average=multidim_average, average=average, diff --git a/tests/unittests/clustering/inputs.py b/tests/unittests/clustering/_inputs.py similarity index 100% rename from tests/unittests/clustering/inputs.py rename to tests/unittests/clustering/_inputs.py diff --git a/tests/unittests/clustering/test_adjusted_mutual_info_score.py b/tests/unittests/clustering/test_adjusted_mutual_info_score.py index 674a39bece8..304ee18bddc 100644 --- a/tests/unittests/clustering/test_adjusted_mutual_info_score.py +++ b/tests/unittests/clustering/test_adjusted_mutual_info_score.py @@ -20,7 +20,7 @@ from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES -from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_adjusted_rand_score.py b/tests/unittests/clustering/test_adjusted_rand_score.py index 862f46c62b3..54e8c1b4577 100644 --- a/tests/unittests/clustering/test_adjusted_rand_score.py +++ b/tests/unittests/clustering/test_adjusted_rand_score.py @@ -17,7 +17,7 @@ from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score -from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers.testers 
import MetricTester diff --git a/tests/unittests/clustering/test_calinski_harabasz_score.py b/tests/unittests/clustering/test_calinski_harabasz_score.py index 9769ea4ee69..6071767364e 100644 --- a/tests/unittests/clustering/test_calinski_harabasz_score.py +++ b/tests/unittests/clustering/test_calinski_harabasz_score.py @@ -16,7 +16,7 @@ from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score -from unittests.clustering.inputs import _single_target_intrinsic1, _single_target_intrinsic2 +from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_davies_bouldin_score.py b/tests/unittests/clustering/test_davies_bouldin_score.py index 5d0efbbf7d5..8f3b4800a01 100644 --- a/tests/unittests/clustering/test_davies_bouldin_score.py +++ b/tests/unittests/clustering/test_davies_bouldin_score.py @@ -16,7 +16,7 @@ from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score -from unittests.clustering.inputs import _single_target_intrinsic1, _single_target_intrinsic2 +from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_dunn_index.py b/tests/unittests/clustering/test_dunn_index.py index 2169ba7935c..fc1500a0fd8 100644 --- a/tests/unittests/clustering/test_dunn_index.py +++ b/tests/unittests/clustering/test_dunn_index.py @@ -19,7 +19,7 @@ from torchmetrics.clustering.dunn_index import DunnIndex from torchmetrics.functional.clustering.dunn_index import dunn_index -from unittests.clustering.inputs import ( +from unittests.clustering._inputs import ( _single_target_intrinsic1, _single_target_intrinsic2, ) @@ -29,7 +29,7 @@ seed_all(42) -def _np_dunn_index(data, labels, p): +def _reference_np_dunn_index(data, labels, p): unique_labels, inverse_indices = np.unique(labels, return_inverse=True) clusters = [data[inverse_indices == label_idx] for label_idx in range(len(unique_labels))] centroids = [c.mean(axis=0) for c in clusters] @@ -69,7 +69,7 @@ def test_dunn_index(self, data, labels, p, ddp): preds=data, target=labels, metric_class=DunnIndex, - reference_metric=partial(_np_dunn_index, p=p), + reference_metric=partial(_reference_np_dunn_index, p=p), metric_args={"p": p}, ) @@ -79,6 +79,6 @@ def test_dunn_index_functional(self, data, labels, p): preds=data, target=labels, metric_functional=dunn_index, - reference_metric=partial(_np_dunn_index, p=p), + reference_metric=partial(_reference_np_dunn_index, p=p), p=p, ) diff --git a/tests/unittests/clustering/test_fowlkes_mallows_index.py b/tests/unittests/clustering/test_fowlkes_mallows_index.py index 1e4f9799795..f880791454f 100644 --- a/tests/unittests/clustering/test_fowlkes_mallows_index.py +++ b/tests/unittests/clustering/test_fowlkes_mallows_index.py @@ -16,7 +16,7 @@ from torchmetrics.clustering import FowlkesMallowsIndex from torchmetrics.functional.clustering import fowlkes_mallows_index -from unittests.clustering.inputs import _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.clustering._inputs import _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers 
import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py index 2a6692793f9..db8224f2f5e 100644 --- a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py +++ b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py @@ -28,14 +28,14 @@ v_measure_score, ) -from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -def _sk_reference(preds, target, fn): +def _reference_sklearn_wrapper(preds, target, fn): """Compute reference values using sklearn.""" return fn(target, preds) @@ -75,7 +75,7 @@ def test_homogeneity_completeness_vmeasure( preds=preds, target=target, metric_class=modular_metric, - reference_metric=partial(_sk_reference, fn=reference_metric), + reference_metric=partial(_reference_sklearn_wrapper, fn=reference_metric), ) def test_homogeneity_completeness_vmeasure_functional( @@ -86,7 +86,7 @@ def test_homogeneity_completeness_vmeasure_functional( preds=preds, target=target, metric_functional=functional_metric, - reference_metric=partial(_sk_reference, fn=reference_metric), + reference_metric=partial(_reference_sklearn_wrapper, fn=reference_metric), ) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 7640452db5d..2a5fd2af1ad 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -18,7 +18,7 @@ from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES -from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_normalized_mutual_info_score.py b/tests/unittests/clustering/test_normalized_mutual_info_score.py index b1c78d1ce0e..e40b807958a 100644 --- a/tests/unittests/clustering/test_normalized_mutual_info_score.py +++ b/tests/unittests/clustering/test_normalized_mutual_info_score.py @@ -20,7 +20,7 @@ from torchmetrics.functional.clustering import normalized_mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES -from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index 40a109c7424..9a2abf1a736 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -17,7 +17,7 @@ from torchmetrics.clustering.rand_score import RandScore from torchmetrics.functional.clustering.rand_score import rand_score -from unittests.clustering.inputs import _float_inputs_extrinsic, 
_single_target_extrinsic1, _single_target_extrinsic2 +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index 078760294b0..96b02d16930 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -57,21 +57,21 @@ _ARGS_1 = {"things": {2}, "stuffs": {3}, "allow_unknown_preds_category": True} _ARGS_2 = {"things": {0, 1}, "stuffs": {6, 7}} -# TODO: Improve _compare_fn by calling https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py +# TODO: Improve _reference_fn by calling https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py # directly and compare at runtime on multiple examples. -def _compare_fn_0_0(preds, target) -> np.ndarray: +def _reference_fn_0_0(preds, target) -> np.ndarray: """Baseline result for the _INPUTS_0, _ARGS_0 combination.""" return np.array([0.7753]) -def _compare_fn_0_1(preds, target) -> np.ndarray: +def _reference_fn_0_1(preds, target) -> np.ndarray: """Baseline result for the _INPUTS_0, _ARGS_1 combination.""" return np.array([np.nan]) -def _compare_fn_1_2(preds, target) -> np.ndarray: +def _reference_fn_1_2(preds, target) -> np.ndarray: """Baseline result for the _INPUTS_1, _ARGS_2 combination.""" return np.array([23 / 30]) @@ -83,9 +83,9 @@ class TestModifiedPanopticQuality(MetricTester): @pytest.mark.parametrize( ("inputs", "args", "reference_metric"), [ - (_INPUTS_0, _ARGS_0, _compare_fn_0_0), - (_INPUTS_0, _ARGS_1, _compare_fn_0_1), - (_INPUTS_1, _ARGS_2, _compare_fn_1_2), + (_INPUTS_0, _ARGS_0, _reference_fn_0_0), + (_INPUTS_0, _ARGS_1, _reference_fn_0_1), + (_INPUTS_1, _ARGS_2, _reference_fn_1_2), ], ) def test_panoptic_quality_class(self, ddp, inputs, args, reference_metric): @@ -106,7 +106,7 @@ def test_panoptic_quality_functional(self): _INPUTS_0.preds, _INPUTS_0.target, metric_functional=modified_panoptic_quality, - reference_metric=_compare_fn_0_0, + reference_metric=_reference_fn_0_0, metric_args=_ARGS_0, ) diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index 24aa4c1f688..9a2b801e0f4 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -64,21 +64,21 @@ _ARGS_1 = {"things": {2}, "stuffs": {3}, "allow_unknown_preds_category": True} _ARGS_2 = {"things": {0, 1}, "stuffs": {10, 11}} -# TODO: Improve _compare_fn by calling https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py +# TODO: Improve _reference_fn by calling https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py # directly and compare at runtime on multiple examples. 
-def _compare_fn_0_0(preds, target) -> np.ndarray: +def _reference_fn_0_0(preds, target) -> np.ndarray: """Baseline result for the _INPUTS_0, _ARGS_0 combination.""" return np.array([0.7753]) -def _compare_fn_0_1(preds, target) -> np.ndarray: +def _reference_fn_0_1(preds, target) -> np.ndarray: """Baseline result for the _INPUTS_0, _ARGS_1 combination.""" return np.array([np.nan]) -def _compare_fn_1_2(preds, target) -> np.ndarray: +def _reference_fn_1_2(preds, target) -> np.ndarray: """Baseline result for the _INPUTS_1, _ARGS_2 combination.""" return np.array([(2 / 3 + 1 + 2 / 3) / 3]) @@ -90,9 +90,9 @@ class TestPanopticQuality(MetricTester): @pytest.mark.parametrize( ("inputs", "args", "reference_metric"), [ - (_INPUTS_0, _ARGS_0, _compare_fn_0_0), - (_INPUTS_0, _ARGS_1, _compare_fn_0_1), - (_INPUTS_1, _ARGS_2, _compare_fn_1_2), + (_INPUTS_0, _ARGS_0, _reference_fn_0_0), + (_INPUTS_0, _ARGS_1, _reference_fn_0_1), + (_INPUTS_1, _ARGS_2, _reference_fn_1_2), ], ) def test_panoptic_quality_class(self, ddp, inputs, args, reference_metric): @@ -113,7 +113,7 @@ def test_panoptic_quality_functional(self): _INPUTS_0.preds, _INPUTS_0.target, metric_functional=panoptic_quality, - reference_metric=_compare_fn_0_0, + reference_metric=_reference_fn_0_0, metric_args=_ARGS_0, ) diff --git a/tests/unittests/image/test_csi.py b/tests/unittests/image/test_csi.py index 259d81b60bf..9b338167889 100644 --- a/tests/unittests/image/test_csi.py +++ b/tests/unittests/image/test_csi.py @@ -31,7 +31,7 @@ _inputs_2 = _Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE)) -def _calculate_ref_metric(preds: torch.Tensor, target: torch.Tensor, threshold: float): +def _reference_sklearn_jaccard(preds: torch.Tensor, target: torch.Tensor, threshold: float): """Calculate reference metric for `CriticalSuccessIndex`.""" preds, target = preds.numpy(), target.numpy() preds = preds >= threshold @@ -58,7 +58,7 @@ def test_csi_class(self, preds, target, threshold, ddp): preds=preds, target=target, metric_class=CriticalSuccessIndex, - reference_metric=partial(_calculate_ref_metric, threshold=threshold), + reference_metric=partial(_reference_sklearn_jaccard, threshold=threshold), metric_args={"threshold": threshold}, ) @@ -68,7 +68,7 @@ def test_csi_functional(self, preds, target, threshold): preds=preds, target=target, metric_functional=critical_success_index, - reference_metric=partial(_calculate_ref_metric, threshold=threshold), + reference_metric=partial(_reference_sklearn_jaccard, threshold=threshold), metric_args={"threshold": threshold}, ) diff --git a/tests/unittests/image/test_d_lambda.py b/tests/unittests/image/test_d_lambda.py index 7312a779f6f..be1ec3ffdf1 100644 --- a/tests/unittests/image/test_d_lambda.py +++ b/tests/unittests/image/test_d_lambda.py @@ -44,13 +44,7 @@ class _Input(NamedTuple): ]: preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) - _inputs.append( - _Input( - preds=preds, - target=target, - p=p, - ) - ) + _inputs.append(_Input(preds=preds, target=target, p=p)) def _baseline_d_lambda(preds: np.ndarray, target: np.ndarray, p: int = 1) -> float: @@ -80,16 +74,12 @@ def _baseline_d_lambda(preds: np.ndarray, target: np.ndarray, p: int = 1) -> flo return (1.0 / (length * (length - 1)) * np.sum(diff)) ** (1.0 / p) -def _np_d_lambda(preds, target, p): +def _reference_numpy_d_lambda(preds, target, p): c, h, w = preds.shape[-3:] np_preds = preds.view(-1, c, h, 
w).permute(0, 2, 3, 1).numpy() np_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() - return _baseline_d_lambda( - np_preds, - np_target, - p=p, - ) + return _baseline_d_lambda(np_preds, np_target, p=p) @pytest.mark.parametrize( @@ -108,8 +98,8 @@ def test_d_lambda(self, preds, target, p, ddp): ddp, preds, target, - SpectralDistortionIndex, - partial(_np_d_lambda, p=p), + metric_class=SpectralDistortionIndex, + reference_metric=partial(_reference_numpy_d_lambda, p=p), metric_args={"p": p}, ) @@ -118,8 +108,8 @@ def test_d_lambda_functional(self, preds, target, p): self.run_functional_metric_test( preds, target, - spectral_distortion_index, - partial(_np_d_lambda, p=p), + metric_functional=spectral_distortion_index, + reference_metric=partial(_reference_numpy_d_lambda, p=p), metric_args={"p": p}, ) diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index 5a763e3cb7c..fb2a0fdf2db 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -73,7 +73,7 @@ class _Input(NamedTuple): ) -def _baseline_d_s( +def _reference_d_s( preds: np.ndarray, ms: np.ndarray, pan: np.ndarray, @@ -125,7 +125,7 @@ def _np_d_s(preds, target, pan=None, pan_lr=None, norm_order=1, window_size=7): np_pan = pan.permute(0, 2, 3, 1).cpu().numpy() np_pan_lr = pan_lr.permute(0, 2, 3, 1).cpu().numpy() if pan_lr is not None else None - return _baseline_d_s( + return _reference_d_s( np_preds, np_ms, np_pan, diff --git a/tests/unittests/image/test_ergas.py b/tests/unittests/image/test_ergas.py index e6d81af5a48..110fa6bf4bc 100644 --- a/tests/unittests/image/test_ergas.py +++ b/tests/unittests/image/test_ergas.py @@ -45,7 +45,7 @@ class _Input(NamedTuple): _inputs.append(_Input(preds=preds, target=preds * coef, ratio=ratio)) -def _baseline_ergas( +def _reference_ergas( preds: Tensor, target: Tensor, ratio: float = 4, @@ -89,8 +89,8 @@ def test_ergas(self, reduction, preds, target, ratio, ddp): ddp, preds, target, - ErrorRelativeGlobalDimensionlessSynthesis, - partial(_baseline_ergas, ratio=ratio, reduction=reduction), + metric_class=ErrorRelativeGlobalDimensionlessSynthesis, + reference_metric=partial(_reference_ergas, ratio=ratio, reduction=reduction), metric_args={"ratio": ratio, "reduction": reduction}, ) @@ -99,8 +99,8 @@ def test_ergas_functional(self, reduction, preds, target, ratio): self.run_functional_metric_test( preds, target, - error_relative_global_dimensionless_synthesis, - partial(_baseline_ergas, ratio=ratio, reduction=reduction), + metric_functional=error_relative_global_dimensionless_synthesis, + reference_metric=partial(_reference_ergas, ratio=ratio, reduction=reduction), metric_args={"ratio": ratio, "reduction": reduction}, ) diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index 75182338960..0a7171ab996 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -39,7 +39,9 @@ class _Input(NamedTuple): ) -def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, normalize: bool = False, reduction: str = "mean") -> Tensor: +def _reference_lpips( + img1: Tensor, img2: Tensor, net_type: str, normalize: bool = False, reduction: str = "mean" +) -> Tensor: """Comparison function for tm implementation.""" ref = LPIPS_reference(net=net_type) res = ref(img1, img2, normalize=normalize).detach().cpu().numpy() @@ -64,7 +66,7 @@ def test_lpips(self, net_type, ddp): preds=_inputs.img1, target=_inputs.img2, metric_class=LearnedPerceptualImagePatchSimilarity, - 
reference_metric=partial(_compare_fn, net_type=net_type), + reference_metric=partial(_reference_lpips, net_type=net_type), check_scriptable=False, check_state_dict=False, metric_args={"net_type": net_type}, @@ -76,7 +78,7 @@ def test_lpips_functional(self): preds=_inputs.img1, target=_inputs.img2, metric_functional=learned_perceptual_image_patch_similarity, - reference_metric=partial(_compare_fn, net_type="alex"), + reference_metric=partial(_reference_lpips, net_type="alex"), metric_args={"net_type": "alex"}, ) @@ -102,7 +104,7 @@ def test_normalize_arg(normalize): """Test that normalize argument works as expected.""" metric = LearnedPerceptualImagePatchSimilarity(net_type="squeeze", normalize=normalize) res = metric(_inputs.img1[0], _inputs.img2[1]) - res2 = _compare_fn(_inputs.img1[0], _inputs.img2[1], net_type="squeeze", normalize=normalize) + res2 = _reference_lpips(_inputs.img1[0], _inputs.img2[1], net_type="squeeze", normalize=normalize) assert res == res2 diff --git a/tests/unittests/image/test_mifid.py b/tests/unittests/image/test_mifid.py index fd96328e5a6..d5bdb95cf68 100644 --- a/tests/unittests/image/test_mifid.py +++ b/tests/unittests/image/test_mifid.py @@ -21,8 +21,11 @@ from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance, NoTrainInceptionV3 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE +from unittests import _reference_cachier -def _compare_mifid(preds, target, cosine_distance_eps: float = 0.1): + +@_reference_cachier +def _reference_mifid(preds, target, cosine_distance_eps: float = 0.1): """Reference implementation. Implementation taken from: @@ -174,7 +177,7 @@ def test_compare_mifid(equal_size): for i in range(m // batch_size): metric.update(img2[batch_size * i : batch_size * (i + 1)].cuda(), real=False) - compare_val = _compare_mifid(img1, img2) + compare_val = _reference_mifid(img1, img2) tm_res = metric.compute() assert torch.allclose(tm_res.cpu(), torch.tensor(compare_val, dtype=tm_res.dtype), atol=1e-3) diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index 51be57c4125..8d71617be53 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial import pytest import torch @@ -39,7 +38,7 @@ ) -def _pytorch_ms_ssim(preds, target, data_range, kernel_size): +def _reference_ms_ssim(preds, target, data_range: float = 1.0, kernel_size: int = 11): return ms_ssim(preds, target, data_range=data_range, win_size=kernel_size, size_average=False) @@ -62,8 +61,8 @@ def test_ms_ssim(self, preds, target, ddp): ddp, preds, target, - MultiScaleStructuralSimilarityIndexMeasure, - partial(_pytorch_ms_ssim, data_range=1.0, kernel_size=11), + metric_class=MultiScaleStructuralSimilarityIndexMeasure, + reference_metric=_reference_ms_ssim, metric_args={"data_range": 1.0, "kernel_size": 11}, ) @@ -72,8 +71,8 @@ def test_ms_ssim_functional(self, preds, target): self.run_functional_metric_test( preds, target, - multiscale_structural_similarity_index_measure, - partial(_pytorch_ms_ssim, data_range=1.0, kernel_size=11), + metric_functional=multiscale_structural_similarity_index_measure, + reference_metric=_reference_ms_ssim, metric_args={"data_range": 1.0, "kernel_size": 11}, ) diff --git a/tests/unittests/image/test_psnr.py b/tests/unittests/image/test_psnr.py index 84269398b51..0cfe9546017 100644 --- a/tests/unittests/image/test_psnr.py +++ b/tests/unittests/image/test_psnr.py @@ -58,7 +58,7 @@ def _to_sk_peak_signal_noise_ratio_inputs(value, dim): return inputs -def _skimage_psnr(preds, target, data_range, reduction, dim): +def _reference_skimage_psnr(preds, target, data_range, reduction, dim): if isinstance(data_range, tuple): preds = preds.clamp(min=data_range[0], max=data_range[1]) target = target.clamp(min=data_range[0], max=data_range[1]) @@ -72,8 +72,8 @@ def _skimage_psnr(preds, target, data_range, reduction, dim): ]) -def _base_e_sk_psnr(preds, target, data_range, reduction, dim): - return _skimage_psnr(preds, target, data_range, reduction, dim) * np.log(10) +def _reference_sklearn_psnr_log(preds, target, data_range, reduction, dim): + return _reference_skimage_psnr(preds, target, data_range, reduction, dim) * np.log(10) @pytest.mark.parametrize( @@ -91,8 +91,8 @@ def _base_e_sk_psnr(preds, target, data_range, reduction, dim): @pytest.mark.parametrize( "base, ref_metric", [ - (10.0, _skimage_psnr), - (2.718281828459045, _base_e_sk_psnr), + (10.0, _reference_skimage_psnr), + (2.718281828459045, _reference_sklearn_psnr_log), ], ) class TestPSNR(MetricTester): @@ -106,8 +106,8 @@ def test_psnr(self, preds, target, data_range, base, reduction, dim, ref_metric, ddp, preds, target, - PeakSignalNoiseRatio, - partial(ref_metric, data_range=data_range, reduction=reduction, dim=dim), + metric_class=PeakSignalNoiseRatio, + reference_metric=partial(ref_metric, data_range=data_range, reduction=reduction, dim=dim), metric_args=_args, ) @@ -117,8 +117,8 @@ def test_psnr_functional(self, preds, target, ref_metric, data_range, base, redu self.run_functional_metric_test( preds, target, - peak_signal_noise_ratio, - partial(ref_metric, data_range=data_range, reduction=reduction, dim=dim), + metric_functional=peak_signal_noise_ratio, + reference_metric=partial(ref_metric, data_range=data_range, reduction=reduction, dim=dim), metric_args=_args, ) diff --git a/tests/unittests/image/test_psnrb.py b/tests/unittests/image/test_psnrb.py index 78da355abb7..077af420601 100644 --- a/tests/unittests/image/test_psnrb.py +++ b/tests/unittests/image/test_psnrb.py @@ -35,7 +35,7 @@ ) -def _ref_metric(preds, target): +def _reference_psnrb(preds, target): """Reference implementation of PSNRB metric. 
Inspired by @@ -66,11 +66,18 @@ class TestPSNR(MetricTester): @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_psnr(self, preds, target, ddp): """Test that modular PSNRB metric returns the same result as the reference implementation.""" - self.run_class_metric_test(ddp, preds, target, PeakSignalNoiseRatioWithBlockedEffect, _ref_metric) + self.run_class_metric_test( + ddp, preds, target, metric_class=PeakSignalNoiseRatioWithBlockedEffect, reference_metric=_reference_psnrb + ) def test_psnr_functional(self, preds, target): """Test that functional PSNRB metric returns the same result as the reference implementation.""" - self.run_functional_metric_test(preds, target, peak_signal_noise_ratio_with_blocked_effect, _ref_metric) + self.run_functional_metric_test( + preds, + target, + metric_functional=peak_signal_noise_ratio_with_blocked_effect, + reference_metric=_reference_psnrb, + ) def test_psnr_half_cpu(self, preds, target): """Test that PSNRB metric works with half precision on cpu.""" diff --git a/tests/unittests/image/test_qnr.py b/tests/unittests/image/test_qnr.py index d09ddf7db27..e1e680beed6 100644 --- a/tests/unittests/image/test_qnr.py +++ b/tests/unittests/image/test_qnr.py @@ -15,7 +15,6 @@ from functools import partial from typing import Dict, List, NamedTuple -import numpy as np import pytest import torch from torch import Tensor @@ -26,7 +25,7 @@ from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester from unittests.image.test_d_lambda import _baseline_d_lambda -from unittests.image.test_d_s import _baseline_d_s +from unittests.image.test_d_s import _reference_d_s seed_all(42) @@ -76,45 +75,24 @@ class _Input(NamedTuple): ) -def _baseline_quality_with_no_reference( - preds: np.ndarray, - ms: np.ndarray, - pan: np.ndarray, - pan_lr: np.ndarray = None, - alpha: float = 1, - beta: float = 1, - norm_order: int = 1, - window_size: int = 7, -) -> float: - """NumPy based implementation of Quality with No Reference, which uses D_lambda and D_s.""" - d_lambda = _baseline_d_lambda(preds, ms, norm_order) - d_s = _baseline_d_s(preds, ms, pan, pan_lr, norm_order, window_size) - return (1 - d_lambda) ** alpha * (1 - d_s) ** beta - - -def _np_quality_with_no_reference(preds, target, pan=None, pan_lr=None, alpha=1, beta=1, norm_order=1, window_size=7): - np_preds = preds.permute(0, 2, 3, 1).cpu().numpy() +def _reference_numpy_quality_with_no_ref( + preds, target, pan=None, pan_lr=None, alpha=1, beta=1, norm_order=1, window_size=7 +): + preds = preds.permute(0, 2, 3, 1).cpu().numpy() if isinstance(target, dict): assert "ms" in target, "Expected `target` to contain 'ms'." - np_ms = target["ms"].permute(0, 2, 3, 1).cpu().numpy() + ms = target["ms"].permute(0, 2, 3, 1).cpu().numpy() assert "pan" in target, "Expected `target` to contain 'pan'." 
- np_pan = target["pan"].permute(0, 2, 3, 1).cpu().numpy() - np_pan_lr = target["pan_lr"].permute(0, 2, 3, 1).cpu().numpy() if "pan_lr" in target else None + pan = target["pan"].permute(0, 2, 3, 1).cpu().numpy() + pan_lr = target["pan_lr"].permute(0, 2, 3, 1).cpu().numpy() if "pan_lr" in target else None else: - np_ms = target.permute(0, 2, 3, 1).cpu().numpy() - np_pan = pan.permute(0, 2, 3, 1).cpu().numpy() - np_pan_lr = pan_lr.permute(0, 2, 3, 1).cpu().numpy() if pan_lr is not None else None - - return _baseline_quality_with_no_reference( - np_preds, - np_ms, - np_pan, - np_pan_lr, - alpha=alpha, - beta=beta, - norm_order=norm_order, - window_size=window_size, - ) + ms = target.permute(0, 2, 3, 1).cpu().numpy() + pan = pan.permute(0, 2, 3, 1).cpu().numpy() + pan_lr = pan_lr.permute(0, 2, 3, 1).cpu().numpy() if pan_lr is not None else None + + d_lambda = _baseline_d_lambda(preds, ms, norm_order) + d_s = _reference_d_s(preds, ms, pan, pan_lr, norm_order=norm_order, window_size=window_size) + return (1 - d_lambda) ** alpha * (1 - d_s) ** beta def _invoke_quality_with_no_reference(preds, target, ms, pan, pan_lr, alpha, beta, norm_order, window_size): @@ -141,7 +119,7 @@ def test_quality_with_no_reference(self, preds, target, ms, pan, pan_lr, alpha, preds, target, QualityWithNoReference, - partial(_np_quality_with_no_reference, norm_order=norm_order, window_size=window_size), + partial(_reference_numpy_quality_with_no_ref, norm_order=norm_order, window_size=window_size), metric_args={"alpha": alpha, "beta": beta, "norm_order": norm_order, "window_size": window_size}, ) @@ -154,7 +132,11 @@ def test_quality_with_no_reference_functional( ms, quality_with_no_reference, partial( - _np_quality_with_no_reference, alpha=alpha, beta=beta, norm_order=norm_order, window_size=window_size + _reference_numpy_quality_with_no_ref, + alpha=alpha, + beta=beta, + norm_order=norm_order, + window_size=window_size, ), metric_args={"alpha": alpha, "beta": beta, "norm_order": norm_order, "window_size": window_size}, fragment_kwargs=True, diff --git a/tests/unittests/image/test_rase.py b/tests/unittests/image/test_rase.py index 1a1488838c2..8015153fd19 100644 --- a/tests/unittests/image/test_rase.py +++ b/tests/unittests/image/test_rase.py @@ -44,7 +44,7 @@ class _InputWindowSized(NamedTuple): _inputs.append(_InputWindowSized(preds=preds, target=target, window_size=window_size)) -def _sewar_rase(preds, target, window_size): +def _reference_sewar_rase(preds, target, window_size): """Baseline implementation of metric. 
This custom implementation is necessary since sewar only supports single image and aggregation therefore needs @@ -83,8 +83,8 @@ def test_rase(self, preds, target, window_size, ddp): ddp, preds, target, - RelativeAverageSpectralError, - partial(_sewar_rase, window_size=window_size), + metric_class=RelativeAverageSpectralError, + reference_metric=partial(_reference_sewar_rase, window_size=window_size), metric_args={"window_size": window_size}, check_batch=False, ) @@ -94,7 +94,7 @@ def test_rase_functional(self, preds, target, window_size): self.run_functional_metric_test( preds, target, - relative_average_spectral_error, - partial(_sewar_rase, window_size=window_size), + metric_functional=relative_average_spectral_error, + reference_metric=partial(_reference_sewar_rase, window_size=window_size), metric_args={"window_size": window_size}, ) diff --git a/tests/unittests/image/test_rmse_sw.py b/tests/unittests/image/test_rmse_sw.py index 54c4757d8c8..a989de625a7 100644 --- a/tests/unittests/image/test_rmse_sw.py +++ b/tests/unittests/image/test_rmse_sw.py @@ -43,7 +43,7 @@ class _InputWindowSized(NamedTuple): _inputs.append(_InputWindowSized(preds=preds, target=target, window_size=window_size)) -def _sewar_rmse_sw(preds, target, window_size): +def _reference_sewar_rmse_sw(preds, target, window_size): rmse_mean = torch.tensor(0.0, dtype=preds.dtype) preds = preds.permute(0, 2, 3, 1).numpy() @@ -69,8 +69,8 @@ def test_rmse_sw(self, preds, target, window_size, ddp): ddp, preds, target, - RootMeanSquaredErrorUsingSlidingWindow, - partial(_sewar_rmse_sw, window_size=window_size), + metric_class=RootMeanSquaredErrorUsingSlidingWindow, + reference_metric=partial(_reference_sewar_rmse_sw, window_size=window_size), metric_args={"window_size": window_size}, ) @@ -79,7 +79,7 @@ def test_rmse_sw_functional(self, preds, target, window_size): self.run_functional_metric_test( preds, target, - root_mean_squared_error_using_sliding_window, - partial(_sewar_rmse_sw, window_size=window_size), + metric_functional=root_mean_squared_error_using_sliding_window, + reference_metric=partial(_reference_sewar_rmse_sw, window_size=window_size), metric_args={"window_size": window_size}, ) diff --git a/tests/unittests/image/test_sam.py b/tests/unittests/image/test_sam.py index fe9a9db6f5c..9cf83196866 100644 --- a/tests/unittests/image/test_sam.py +++ b/tests/unittests/image/test_sam.py @@ -40,11 +40,7 @@ _inputs.append(_Input(preds=preds, target=target)) -def _baseline_sam( - preds: Tensor, - target: Tensor, - reduction: str = "elementwise_mean", -) -> Tensor: +def _reference_sam(preds: Tensor, target: Tensor, reduction: str = "elementwise_mean") -> Tensor: """Baseline implementation of spectral angle mapper.""" reduction_options = ("elementwise_mean", "sum", "none") if reduction not in reduction_options: @@ -74,8 +70,8 @@ def test_sam(self, reduction, preds, target, ddp): ddp, preds, target, - SpectralAngleMapper, - partial(_baseline_sam, reduction=reduction), + metric_class=SpectralAngleMapper, + reference_metric=partial(_reference_sam, reduction=reduction), metric_args={"reduction": reduction}, ) @@ -84,8 +80,8 @@ def test_sam_functional(self, reduction, preds, target): self.run_functional_metric_test( preds, target, - spectral_angle_mapper, - partial(_baseline_sam, reduction=reduction), + metric_functional=spectral_angle_mapper, + reference_metric=partial(_reference_sam, reduction=reduction), metric_args={"reduction": reduction}, ) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py 
index 24ad09935fa..a84f2d83468 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -53,7 +53,7 @@ ) -def _skimage_ssim( +def _reference_skimage_ssim( preds, target, data_range, @@ -116,7 +116,7 @@ def _skimage_ssim( return results, fullimages -def _pt_ssim( +def _reference_msssim_ssim( preds, target, data_range, @@ -147,8 +147,8 @@ def test_ssim_sk(self, preds, target, sigma, data_range, ddp): ddp, preds, target, - StructuralSimilarityIndexMeasure, - partial(_skimage_ssim, data_range=data_range, sigma=sigma, kernel_size=None), + metric_class=StructuralSimilarityIndexMeasure, + reference_metric=partial(_reference_skimage_ssim, data_range=data_range, sigma=sigma, kernel_size=None), metric_args={ "data_range": data_range, "sigma": sigma, @@ -162,8 +162,8 @@ def test_ssim_pt(self, preds, target, sigma, ddp): ddp, preds, target, - StructuralSimilarityIndexMeasure, - partial(_pt_ssim, data_range=1.0, sigma=sigma), + metric_class=StructuralSimilarityIndexMeasure, + reference_metric=partial(_reference_msssim_ssim, data_range=1.0, sigma=sigma), metric_args={ "data_range": 1.0, "sigma": sigma, @@ -177,8 +177,8 @@ def test_ssim_without_gaussian_kernel(self, preds, target, sigma, ddp): ddp, preds, target, - StructuralSimilarityIndexMeasure, - partial(_skimage_ssim, data_range=1.0, sigma=sigma, kernel_size=None), + metric_class=StructuralSimilarityIndexMeasure, + reference_metric=partial(_reference_skimage_ssim, data_range=1.0, sigma=sigma, kernel_size=None), metric_args={ "gaussian_kernel": False, "data_range": 1.0, @@ -192,8 +192,10 @@ def test_ssim_functional_sk(self, preds, target, sigma, reduction_arg): self.run_functional_metric_test( preds, target, - structural_similarity_index_measure, - partial(_skimage_ssim, data_range=1.0, sigma=sigma, kernel_size=None, reduction_arg=reduction_arg), + metric_functional=structural_similarity_index_measure, + reference_metric=partial( + _reference_skimage_ssim, data_range=1.0, sigma=sigma, kernel_size=None, reduction_arg=reduction_arg + ), metric_args={"data_range": 1.0, "sigma": sigma, "reduction": reduction_arg}, ) @@ -203,8 +205,8 @@ def test_ssim_functional_pt(self, preds, target, sigma, reduction_arg): self.run_functional_metric_test( preds, target, - structural_similarity_index_measure, - partial(_pt_ssim, data_range=1.0, sigma=sigma, reduction_arg=reduction_arg), + metric_functional=structural_similarity_index_measure, + reference_metric=partial(_reference_msssim_ssim, data_range=1.0, sigma=sigma, reduction_arg=reduction_arg), metric_args={"data_range": 1.0, "sigma": sigma, "reduction": reduction_arg}, ) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 59cb0da75cc..1842b693046 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -36,11 +36,11 @@ def update(self, img, *args: Any): super().update(img=img) -def _total_variaion_tester(preds, target, reduction="mean"): +def _total_variaion_wrapped(preds, target, reduction="mean"): return total_variation(preds, reduction) -def _total_variation_kornia_tester(preds, target, reduction): +def _reference_kornia_tv(preds, target, reduction): score = kornia_total_variation(preds).sum(-1) if reduction == "sum": return score.sum() @@ -82,7 +82,7 @@ def test_total_variation(self, preds, target, reduction, ddp): preds, target, TotalVariationTester, - partial(_total_variation_kornia_tester, reduction=reduction), + partial(_reference_kornia_tv, reduction=reduction), metric_args={"reduction": reduction}, ) @@ 
-91,8 +91,8 @@ def test_total_variation_functional(self, preds, target, reduction): self.run_functional_metric_test( preds, target, - _total_variaion_tester, - partial(_total_variation_kornia_tester, reduction=reduction), + _total_variaion_wrapped, + partial(_reference_kornia_tv, reduction=reduction), metric_args={"reduction": reduction}, ) @@ -102,13 +102,13 @@ def test_sam_half_cpu(self, preds, target, reduction): preds, target, TotalVariationTester, - _total_variaion_tester, + _total_variaion_wrapped, ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") def test_sam_half_gpu(self, preds, target, reduction): """Test for half precision on GPU.""" - self.run_precision_test_gpu(preds, target, TotalVariationTester, _total_variaion_tester) + self.run_precision_test_gpu(preds, target, TotalVariationTester, _total_variaion_wrapped) def test_correct_args(): diff --git a/tests/unittests/image/test_uqi.py b/tests/unittests/image/test_uqi.py index 85f8bc37a72..22ff2b7479b 100644 --- a/tests/unittests/image/test_uqi.py +++ b/tests/unittests/image/test_uqi.py @@ -55,7 +55,7 @@ class _InputMultichannel(NamedTuple): ) -def _skimage_uqi(preds, target, multichannel, kernel_size): +def _reference_skimage_uqi(preds, target, multichannel, kernel_size): c, h, w = preds.shape[-3:] sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() @@ -93,8 +93,8 @@ def test_uqi(self, preds, target, multichannel, kernel_size, ddp): ddp, preds, target, - UniversalImageQualityIndex, - partial(_skimage_uqi, multichannel=multichannel, kernel_size=kernel_size), + metric_class=UniversalImageQualityIndex, + reference_metric=partial(_reference_skimage_uqi, multichannel=multichannel, kernel_size=kernel_size), metric_args={"kernel_size": (kernel_size, kernel_size)}, ) @@ -103,8 +103,8 @@ def test_uqi_functional(self, preds, target, multichannel, kernel_size): self.run_functional_metric_test( preds, target, - universal_image_quality_index, - partial(_skimage_uqi, multichannel=multichannel, kernel_size=kernel_size), + metric_functional=universal_image_quality_index, + reference_metric=partial(_reference_skimage_uqi, multichannel=multichannel, kernel_size=kernel_size), metric_args={"kernel_size": (kernel_size, kernel_size)}, ) diff --git a/tests/unittests/image/test_vif.py b/tests/unittests/image/test_vif.py index 421c98ad1e1..bde00c969b5 100644 --- a/tests/unittests/image/test_vif.py +++ b/tests/unittests/image/test_vif.py @@ -35,7 +35,7 @@ ] -def _sewar_vif(preds, target, sigma_nsq=2): +def _reference_sewar_vif(preds, target, sigma_nsq=2): preds = torch.movedim(preds, 1, -1) target = torch.movedim(target, 1, -1) preds = preds.cpu().numpy() @@ -53,8 +53,12 @@ class TestVIF(MetricTester): @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_vif(self, preds, target, ddp): """Test class implementation of metric.""" - self.run_class_metric_test(ddp, preds, target, VisualInformationFidelity, _sewar_vif) + self.run_class_metric_test( + ddp, preds, target, metric_class=VisualInformationFidelity, reference_metric=_reference_sewar_vif + ) def test_vif_functional(self, preds, target): """Test functional implementation of metric.""" - self.run_functional_metric_test(preds, target, visual_information_fidelity, _sewar_vif) + self.run_functional_metric_test( + preds, target, metric_functional=visual_information_fidelity, reference_metric=_reference_sewar_vif + ) diff --git 
a/tests/unittests/multimodal/test_clip_iqa.py b/tests/unittests/multimodal/test_clip_iqa.py index 05421dc55b1..c7057226759 100644 --- a/tests/unittests/multimodal/test_clip_iqa.py +++ b/tests/unittests/multimodal/test_clip_iqa.py @@ -60,12 +60,12 @@ def compute(self): return super().compute().sum() -def _clip_iqa_tester(preds, target): +def _clip_iqa_wrapped(preds, target): """Tester function for `clip_image_quality_assessment` that supports two input arguments.""" return clip_image_quality_assessment(preds) -def _reference(preds, target, reduce=False): +def _reference_clip_iqa(preds, target, reduce=False): """Reference implementation of `CLIPImageQualityAssessment` metric.""" res = piq.CLIPIQA()(preds).squeeze() return res.sum() if reduce else res @@ -85,7 +85,7 @@ def test_clip_iqa(self, ddp): preds=torch.rand(2, 1, 3, 128, 128), target=torch.rand(2, 1, 3, 128, 128), metric_class=CLIPTesterClass, - reference_metric=partial(_reference, reduce=True), + reference_metric=partial(_reference_clip_iqa, reduce=True), check_scriptable=False, check_state_dict=False, ) @@ -98,8 +98,8 @@ def test_clip_iqa_functional(self, shapes): self.run_functional_metric_test( preds=img, target=img, - metric_functional=_clip_iqa_tester, - reference_metric=_reference, + metric_functional=_clip_iqa_wrapped, + reference_metric=_reference_clip_iqa, ) diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index e506dc89d74..110266e6525 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -48,7 +48,7 @@ class _InputImagesCaptions(NamedTuple): ) -def _compare_fn(preds, target, model_name_or_path): +def _reference_clip_score(preds, target, model_name_or_path): processor = _CLIPProcessor.from_pretrained(model_name_or_path) model = _CLIPModel.from_pretrained(model_name_or_path) inputs = processor(text=target, images=[p.cpu() for p in preds], return_tensors="pt", padding=True) @@ -75,7 +75,7 @@ def test_clip_score(self, inputs, model_name_or_path, ddp): preds=preds, target=target, metric_class=CLIPScore, - reference_metric=partial(_compare_fn, model_name_or_path=model_name_or_path), + reference_metric=partial(_reference_clip_score, model_name_or_path=model_name_or_path), metric_args={"model_name_or_path": model_name_or_path}, check_scriptable=False, check_state_dict=False, @@ -90,7 +90,7 @@ def test_clip_score_functional(self, inputs, model_name_or_path): preds=preds, target=target, metric_functional=clip_score, - reference_metric=partial(_compare_fn, model_name_or_path=model_name_or_path), + reference_metric=partial(_reference_clip_score, model_name_or_path=model_name_or_path), metric_args={"model_name_or_path": model_name_or_path}, ) diff --git a/tests/unittests/nominal/test_cramers.py b/tests/unittests/nominal/test_cramers.py index ba5b1e4a6c8..4cebac73e05 100644 --- a/tests/unittests/nominal/test_cramers.py +++ b/tests/unittests/nominal/test_cramers.py @@ -62,7 +62,7 @@ def cramers_matrix_input(): return matrix -def _dython_cramers_v(preds, target, bias_correction, nan_strategy, nan_replace_value): +def _reference_dython_cramers_v(preds, target, bias_correction, nan_strategy, nan_replace_value): preds = preds.argmax(1) if preds.ndim == 2 else preds target = target.argmax(1) if target.ndim == 2 else target @@ -81,7 +81,7 @@ def _dython_cramers_v_matrix(matrix, bias_correction, nan_strategy, nan_replace_ cramers_v_matrix_value = torch.ones(num_variables, num_variables) for i, j in 
itertools.combinations(range(num_variables), 2): x, y = matrix[:, i], matrix[:, j] - cramers_v_matrix_value[i, j] = cramers_v_matrix_value[j, i] = _dython_cramers_v( + cramers_v_matrix_value[i, j] = cramers_v_matrix_value[j, i] = _reference_dython_cramers_v( x, y, bias_correction, nan_strategy, nan_replace_value ) return cramers_v_matrix_value @@ -113,7 +113,7 @@ def test_cramers_v(self, ddp, preds, target, bias_correction, nan_strategy, nan_ "num_classes": NUM_CLASSES, } reference_metric = partial( - _dython_cramers_v, + _reference_dython_cramers_v, bias_correction=bias_correction, nan_strategy=nan_strategy, nan_replace_value=nan_replace_value, @@ -135,7 +135,7 @@ def test_cramers_v_functional(self, preds, target, bias_correction, nan_strategy "nan_replace_value": nan_replace_value, } reference_metric = partial( - _dython_cramers_v, + _reference_dython_cramers_v, bias_correction=bias_correction, nan_strategy=nan_strategy, nan_replace_value=nan_replace_value, diff --git a/tests/unittests/nominal/test_fleiss_kappa.py b/tests/unittests/nominal/test_fleiss_kappa.py index 7fbae7383de..1538b116ab2 100644 --- a/tests/unittests/nominal/test_fleiss_kappa.py +++ b/tests/unittests/nominal/test_fleiss_kappa.py @@ -27,7 +27,7 @@ NUM_CATEGORIES = NUM_CLASSES -def _compare_func(preds, target, mode): +def _reference_fleiss_kappa(preds, target, mode): if mode == "probs": counts = np.zeros((preds.shape[0], preds.shape[1])) preds = preds.argmax(dim=1) @@ -92,7 +92,7 @@ def test_fleiss_kappa(self, ddp, preds, target, mode): preds=preds, target=target, metric_class=WrappedFleissKappa, - reference_metric=partial(_compare_func, mode=mode), + reference_metric=partial(_reference_fleiss_kappa, mode=mode), metric_args={"mode": mode}, ) @@ -102,7 +102,7 @@ def test_fleiss_kappa_functional(self, preds, target, mode): preds, target, metric_functional=wrapped_fleiss_kappa, - reference_metric=partial(_compare_func, mode=mode), + reference_metric=partial(_reference_fleiss_kappa, mode=mode), metric_args={"mode": mode}, ) diff --git a/tests/unittests/nominal/test_pearson.py b/tests/unittests/nominal/test_pearson.py index 6e3bd90a66d..44bf1c0e415 100644 --- a/tests/unittests/nominal/test_pearson.py +++ b/tests/unittests/nominal/test_pearson.py @@ -55,7 +55,7 @@ def pearson_matrix_input(): ) -def _pd_pearsons_t(preds, target): +def _reference_pd_pearsons_t(preds, target): preds = preds.argmax(1) if preds.ndim == 2 else preds target = target.argmax(1) if target.ndim == 2 else target preds, target = preds.numpy().astype(int), target.numpy().astype(int) @@ -65,12 +65,12 @@ def _pd_pearsons_t(preds, target): return torch.tensor(t) -def _pd_pearsons_t_matrix(matrix): +def _reference_pd_pearsons_t_matrix(matrix): num_variables = matrix.shape[1] pearsons_t_matrix_value = torch.ones(num_variables, num_variables) for i, j in itertools.combinations(range(num_variables), 2): x, y = matrix[:, i], matrix[:, j] - pearsons_t_matrix_value[i, j] = pearsons_t_matrix_value[j, i] = _pd_pearsons_t(x, y) + pearsons_t_matrix_value[i, j] = pearsons_t_matrix_value[j, i] = _reference_pd_pearsons_t(x, y) return pearsons_t_matrix_value @@ -96,14 +96,14 @@ def test_pearsons_ta(self, ddp, preds, target): preds=preds, target=target, metric_class=PearsonsContingencyCoefficient, - reference_metric=_pd_pearsons_t, + reference_metric=_reference_pd_pearsons_t, metric_args=metric_args, ) def test_pearsons_t_functional(self, preds, target): """Test functional implementation of metric.""" self.run_functional_metric_test( - preds, target, 
metric_functional=pearsons_contingency_coefficient, reference_metric=_pd_pearsons_t + preds, target, metric_functional=pearsons_contingency_coefficient, reference_metric=_reference_pd_pearsons_t ) def test_pearsons_t_differentiability(self, preds, target): @@ -122,5 +122,5 @@ def test_pearsons_t_differentiability(self, preds, target): def test_pearsons_contingency_coefficient_matrix(pearson_matrix_input): """Test matrix version of metric works as expected.""" tm_score = pearsons_contingency_coefficient_matrix(pearson_matrix_input) - reference_score = _pd_pearsons_t_matrix(pearson_matrix_input) + reference_score = _reference_pd_pearsons_t_matrix(pearson_matrix_input) assert torch.allclose(tm_score, reference_score) diff --git a/tests/unittests/nominal/test_theils_u.py b/tests/unittests/nominal/test_theils_u.py index 16a66bda48b..b7ae4b29507 100644 --- a/tests/unittests/nominal/test_theils_u.py +++ b/tests/unittests/nominal/test_theils_u.py @@ -62,7 +62,7 @@ def theils_u_matrix_input(): return matrix -def _dython_theils_u(preds, target, nan_strategy, nan_replace_value): +def _reference_dython_theils_u(preds, target, nan_strategy, nan_replace_value): preds = preds.argmax(1) if preds.ndim == 2 else preds target = target.argmax(1) if target.ndim == 2 else target @@ -75,13 +75,13 @@ def _dython_theils_u(preds, target, nan_strategy, nan_replace_value): return torch.tensor(v) -def _dython_theils_u_matrix(matrix, nan_strategy, nan_replace_value): +def _reference_dython_theils_u_matrix(matrix, nan_strategy, nan_replace_value): num_variables = matrix.shape[1] theils_u_matrix_value = torch.ones(num_variables, num_variables) for i, j in itertools.combinations(range(num_variables), 2): x, y = matrix[:, i], matrix[:, j] - theils_u_matrix_value[i, j] = _dython_theils_u(x, y, nan_strategy, nan_replace_value) - theils_u_matrix_value[j, i] = _dython_theils_u(y, x, nan_strategy, nan_replace_value) + theils_u_matrix_value[i, j] = _reference_dython_theils_u(x, y, nan_strategy, nan_replace_value) + theils_u_matrix_value[j, i] = _reference_dython_theils_u(y, x, nan_strategy, nan_replace_value) return theils_u_matrix_value @@ -109,7 +109,7 @@ def test_theils_u(self, ddp, preds, target, nan_strategy, nan_replace_value): "num_classes": NUM_CLASSES, } reference_metric = partial( - _dython_theils_u, + _reference_dython_theils_u, nan_strategy=nan_strategy, nan_replace_value=nan_replace_value, ) @@ -129,7 +129,7 @@ def test_theils_u_functional(self, preds, target, nan_strategy, nan_replace_valu "nan_replace_value": nan_replace_value, } reference_metric = partial( - _dython_theils_u, + _reference_dython_theils_u, nan_strategy=nan_strategy, nan_replace_value=nan_replace_value, ) @@ -158,5 +158,5 @@ def test_theils_u_differentiability(self, preds, target, nan_strategy, nan_repla def test_theils_u_matrix(theils_u_matrix_input, nan_strategy, nan_replace_value): """Test matrix version of metric works as expected.""" tm_score = theils_u_matrix(theils_u_matrix_input, nan_strategy, nan_replace_value) - reference_score = _dython_theils_u_matrix(theils_u_matrix_input, nan_strategy, nan_replace_value) + reference_score = _reference_dython_theils_u_matrix(theils_u_matrix_input, nan_strategy, nan_replace_value) assert torch.allclose(tm_score, reference_score, atol=1e-6) diff --git a/tests/unittests/nominal/test_tschuprows.py b/tests/unittests/nominal/test_tschuprows.py index bbab8e36785..48102ac6f34 100644 --- a/tests/unittests/nominal/test_tschuprows.py +++ b/tests/unittests/nominal/test_tschuprows.py @@ -52,7 +52,7 @@ def 
tschuprows_matrix_input(): ) -def _pd_tschuprows_t(preds, target): +def _reference_pd_tschuprows_t(preds, target): preds = preds.argmax(1) if preds.ndim == 2 else preds target = target.argmax(1) if target.ndim == 2 else target preds, target = preds.numpy().astype(int), target.numpy().astype(int) @@ -62,12 +62,12 @@ def _pd_tschuprows_t(preds, target): return torch.tensor(t) -def _pd_tschuprows_t_matrix(matrix): +def _reference_pd_tschuprows_t_matrix(matrix): num_variables = matrix.shape[1] tschuprows_t_matrix_value = torch.ones(num_variables, num_variables) for i, j in itertools.combinations(range(num_variables), 2): x, y = matrix[:, i], matrix[:, j] - tschuprows_t_matrix_value[i, j] = tschuprows_t_matrix_value[j, i] = _pd_tschuprows_t(x, y) + tschuprows_t_matrix_value[i, j] = tschuprows_t_matrix_value[j, i] = _reference_pd_tschuprows_t(x, y) return tschuprows_t_matrix_value @@ -93,7 +93,7 @@ def test_tschuprows_ta(self, ddp, preds, target): preds=preds, target=target, metric_class=TschuprowsT, - reference_metric=_pd_tschuprows_t, + reference_metric=_reference_pd_tschuprows_t, metric_args=metric_args, ) @@ -101,7 +101,11 @@ def test_tschuprows_t_functional(self, preds, target): """Test functional implementation of metric.""" metric_args = {"bias_correction": False} self.run_functional_metric_test( - preds, target, metric_functional=tschuprows_t, reference_metric=_pd_tschuprows_t, metric_args=metric_args + preds, + target, + metric_functional=tschuprows_t, + reference_metric=_reference_pd_tschuprows_t, + metric_args=metric_args, ) def test_tschuprows_t_differentiability(self, preds, target): @@ -120,5 +124,5 @@ def test_tschuprows_t_differentiability(self, preds, target): def test_tschuprows_t_matrix(tschuprows_matrix_input): """Test matrix version of metric works as expected.""" tm_score = tschuprows_t_matrix(tschuprows_matrix_input, bias_correction=False) - reference_score = _pd_tschuprows_t_matrix(tschuprows_matrix_input) + reference_score = _reference_pd_tschuprows_t_matrix(tschuprows_matrix_input) assert torch.allclose(tm_score, reference_score) diff --git a/tests/unittests/regression/test_concordance.py b/tests/unittests/regression/test_concordance.py index 376385ad570..69668772021 100644 --- a/tests/unittests/regression/test_concordance.py +++ b/tests/unittests/regression/test_concordance.py @@ -49,7 +49,7 @@ ) -def _scipy_concordance(preds, target): +def _reference_scipy_concordance(preds, target): preds, target = preds.numpy(), target.numpy() if preds.ndim == 2: mean_pred = np.mean(preds, axis=0) @@ -89,13 +89,13 @@ def test_concordance_corrcoef(self, preds, target, ddp): preds, target, ConcordanceCorrCoef, - _scipy_concordance, + _reference_scipy_concordance, metric_args={"num_outputs": num_outputs}, ) def test_concordance_corrcoef_functional(self, preds, target): """Test functional implementation of metric.""" - self.run_functional_metric_test(preds, target, concordance_corrcoef, _scipy_concordance) + self.run_functional_metric_test(preds, target, concordance_corrcoef, _reference_scipy_concordance) def test_concordance_corrcoef_differentiability(self, preds, target): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" diff --git a/tests/unittests/regression/test_cosine_similarity.py b/tests/unittests/regression/test_cosine_similarity.py index c65300e82ff..184676526c7 100644 --- a/tests/unittests/regression/test_cosine_similarity.py +++ b/tests/unittests/regression/test_cosine_similarity.py @@ -26,7 +26,7 @@ seed_all(42) -num_targets = 5 
+NUM_TARGETS = 5 _single_target_inputs = _Input( @@ -35,12 +35,12 @@ ) _multi_target_inputs = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), ) -def _ref_metric(preds, target, reduction): +def _reference_sklearn_cosine(preds, target, reduction): sk_preds = preds.numpy() sk_target = target.numpy() result_array = sk_cosine(sk_target, sk_preds) @@ -57,8 +57,8 @@ def _ref_metric(preds, target, reduction): @pytest.mark.parametrize( "preds, target, ref_metric", [ - (_single_target_inputs.preds, _single_target_inputs.target, _ref_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _ref_metric), + (_single_target_inputs.preds, _single_target_inputs.target, _reference_sklearn_cosine), + (_multi_target_inputs.preds, _multi_target_inputs.target, _reference_sklearn_cosine), ], ) class TestCosineSimilarity(MetricTester): diff --git a/tests/unittests/regression/test_explained_variance.py b/tests/unittests/regression/test_explained_variance.py index 8c6218ba45a..629e1ea7932 100644 --- a/tests/unittests/regression/test_explained_variance.py +++ b/tests/unittests/regression/test_explained_variance.py @@ -25,7 +25,7 @@ seed_all(42) -num_targets = 5 +NUM_TARGETS = 5 _single_target_inputs = _Input( @@ -34,8 +34,8 @@ ) _multi_target_inputs = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), ) @@ -46,8 +46,8 @@ def _single_target_ref_metric(preds, target, sk_fn=explained_variance_score): def _multi_target_ref_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() + sk_preds = preds.view(-1, NUM_TARGETS).numpy() + sk_target = target.view(-1, NUM_TARGETS).numpy() return sk_fn(sk_target, sk_preds) diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index 1a2d3d000fa..c7e5747a0ba 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -47,7 +47,7 @@ ) -def _scipy_kendall(preds, target, alternative, variant): +def _reference_scipy_kendall(preds, target, alternative, variant): metric_args = {} if _SCIPY_GREATER_EQUAL_1_8: metric_args = {"alternative": alternative or "two-sided"} # scipy cannot accept `None` @@ -94,7 +94,7 @@ def test_kendall_rank_corrcoef(self, preds, target, alternative, variant, ddp): """Test class implementation of metric.""" num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 t_test = bool(alternative is not None) - _sk_kendall_tau = partial(_scipy_kendall, alternative=alternative, variant=variant) + _sk_kendall_tau = partial(_reference_scipy_kendall, alternative=alternative, variant=variant) alternative = _adjust_alternative_to_scipy(alternative) self.run_class_metric_test( @@ -111,7 +111,7 @@ def test_kendall_rank_corrcoef_functional(self, preds, target, alternative, vari t_test = bool(alternative is not None) alternative = _adjust_alternative_to_scipy(alternative) metric_args = {"t_test": t_test, "alternative": alternative, "variant": variant} - _sk_kendall_tau = partial(_scipy_kendall, alternative=alternative, variant=variant) + _sk_kendall_tau = partial(_reference_scipy_kendall, 
alternative=alternative, variant=variant) self.run_functional_metric_test(preds, target, kendall_rank_corrcoef, _sk_kendall_tau, metric_args=metric_args) def test_kendall_rank_corrcoef_differentiability(self, preds, target, alternative, variant): diff --git a/tests/unittests/regression/test_log_cosh_error.py b/tests/unittests/regression/test_log_cosh_error.py index 0d6a9f1f14b..74b6214719b 100644 --- a/tests/unittests/regression/test_log_cosh_error.py +++ b/tests/unittests/regression/test_log_cosh_error.py @@ -25,7 +25,7 @@ seed_all(42) -num_targets = 5 +NUM_TARGETS = 5 _single_target_inputs = _Input( @@ -34,12 +34,12 @@ ) _multi_target_inputs = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), ) -def _sk_log_cosh_error(preds, target): +def _reference_log_cosh_error(preds, target): preds, target = preds.numpy(), target.numpy() diff = preds - target if diff.ndim == 1: @@ -60,13 +60,13 @@ class TestLogCoshError(MetricTester): @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_log_cosh_error_class(self, ddp, preds, target): """Test class implementation of metric.""" - num_outputs = 1 if preds.ndim == 2 else num_targets + num_outputs = 1 if preds.ndim == 2 else NUM_TARGETS self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=LogCoshError, - reference_metric=_sk_log_cosh_error, + reference_metric=_reference_log_cosh_error, metric_args={"num_outputs": num_outputs}, ) @@ -76,12 +76,12 @@ def test_log_cosh_error_functional(self, preds, target): preds=preds, target=target, metric_functional=log_cosh_error, - reference_metric=_sk_log_cosh_error, + reference_metric=_reference_log_cosh_error, ) def test_log_cosh_error_differentiability(self, preds, target): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" - num_outputs = 1 if preds.ndim == 2 else num_targets + num_outputs = 1 if preds.ndim == 2 else NUM_TARGETS self.run_differentiability_test( preds=preds, target=target, diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index 67e297f1b4f..c25882c3f37 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -47,7 +47,7 @@ seed_all(42) -num_targets = 5 +NUM_TARGETS = 5 _single_target_inputs = _Input( @@ -56,12 +56,12 @@ ) _multi_target_inputs = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), ) -def _baseline_symmetric_mape( +def _reference_symmetric_mape( y_true: np.ndarray, y_pred: np.ndarray, sample_weight: Optional[np.ndarray] = None, @@ -114,11 +114,11 @@ def _baseline_symmetric_mape( return np.average(output_errors, weights=multioutput) -def _sk_weighted_mean_abs_percentage_error(target, preds): +def _reference_weighted_mean_abs_percentage_error(target, preds): return np.sum(np.abs(target - preds)) / np.sum(np.abs(target)) -def _single_target_ref_metric(preds, target, sk_fn, metric_args): +def _single_target_ref_wrapper(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -127,9 +127,9 @@ def _single_target_ref_metric(preds, target, sk_fn, 
metric_args): return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res -def _multi_target_ref_metric(preds, target, sk_fn, metric_args): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() +def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): + sk_preds = preds.view(-1, NUM_TARGETS).numpy() + sk_target = target.view(-1, NUM_TARGETS).numpy() sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {} res = sk_fn(sk_target, sk_preds, **sk_kwargs) return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res @@ -138,8 +138,8 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args): @pytest.mark.parametrize( "preds, target, ref_metric", [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_metric), + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper), ], ) @pytest.mark.parametrize( @@ -147,20 +147,20 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args): [ (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}), (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": num_targets}), + (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": NUM_TARGETS}), (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}), (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}), ( SymmetricMeanAbsolutePercentageError, symmetric_mean_absolute_percentage_error, - _baseline_symmetric_mape, + _reference_symmetric_mape, {}, ), (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}), ( WeightedMeanAbsolutePercentageError, weighted_mean_absolute_percentage_error, - _sk_weighted_mean_abs_percentage_error, + _reference_weighted_mean_abs_percentage_error, {}, ), ], diff --git a/tests/unittests/regression/test_minkowski_distance.py b/tests/unittests/regression/test_minkowski_distance.py index 7fbd62d7cc5..e00ccdaf7c4 100644 --- a/tests/unittests/regression/test_minkowski_distance.py +++ b/tests/unittests/regression/test_minkowski_distance.py @@ -13,7 +13,7 @@ seed_all(42) -num_targets = 5 +NUM_TARGETS = 5 _single_target_inputs = _Input( @@ -22,18 +22,18 @@ ) _multi_target_inputs = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), ) -def _sk_metric_single_target(preds, target, p): +def _reference_scipy_metric_single_target(preds, target, p): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return scipy_minkowski(sk_preds, sk_target, p=p) -def _sk_metric_multi_target(preds, target, p): +def _reference_scipy_metric_multi_target(preds, target, p): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return scipy_minkowski(sk_preds, sk_target, p=p) @@ -42,8 +42,8 @@ def _sk_metric_multi_target(preds, target, p): @pytest.mark.parametrize( "preds, target, ref_metric", [ - (_single_target_inputs.preds, _single_target_inputs.target, 
_sk_metric_single_target), - (_multi_target_inputs.preds, _multi_target_inputs.target, _sk_metric_multi_target), + (_single_target_inputs.preds, _single_target_inputs.target, _reference_scipy_metric_single_target), + (_multi_target_inputs.preds, _multi_target_inputs.target, _reference_scipy_metric_multi_target), ], ) @pytest.mark.parametrize("p", [1, 2, 4, 1.5]) diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index dd5f35addac..d7ed5b27bfd 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -48,7 +48,7 @@ ) -def _scipy_pearson(preds, target): +def _reference_scipy_pearson(preds, target): if preds.ndim == 2: return [pearsonr(t.numpy(), p.numpy())[0] for t, p in zip(target.T, preds.T)] return pearsonr(target.numpy(), preds.numpy())[0] @@ -78,14 +78,14 @@ def test_pearson_corrcoef(self, preds, target, compute_on_cpu, ddp): preds=preds, target=target, metric_class=PearsonCorrCoef, - reference_metric=_scipy_pearson, + reference_metric=_reference_scipy_pearson, metric_args={"num_outputs": num_outputs, "compute_on_cpu": compute_on_cpu}, ) def test_pearson_corrcoef_functional(self, preds, target): """Test functional implementation of metric.""" self.run_functional_metric_test( - preds=preds, target=target, metric_functional=pearson_corrcoef, reference_metric=_scipy_pearson + preds=preds, target=target, metric_functional=pearson_corrcoef, reference_metric=_reference_scipy_pearson ) def test_pearson_corrcoef_differentiability(self, preds, target): diff --git a/tests/unittests/regression/test_r2.py b/tests/unittests/regression/test_r2.py index 8ed4a0ddd8a..adcdc8a2dc8 100644 --- a/tests/unittests/regression/test_r2.py +++ b/tests/unittests/regression/test_r2.py @@ -25,7 +25,7 @@ seed_all(42) -num_targets = 5 +NUM_TARGETS = 5 _single_target_inputs = _Input( @@ -34,12 +34,12 @@ ) _multi_target_inputs = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), ) -def _single_target_ref_metric(preds, target, adjusted, multioutput): +def _single_target_ref_wrapper(preds, target, adjusted, multioutput): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) @@ -48,9 +48,9 @@ def _single_target_ref_metric(preds, target, adjusted, multioutput): return r2_score -def _multi_target_ref_metric(preds, target, adjusted, multioutput): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() +def _multi_target_ref_wrapper(preds, target, adjusted, multioutput): + sk_preds = preds.view(-1, NUM_TARGETS).numpy() + sk_target = target.view(-1, NUM_TARGETS).numpy() r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) if adjusted != 0: return 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) @@ -62,8 +62,8 @@ def _multi_target_ref_metric(preds, target, adjusted, multioutput): @pytest.mark.parametrize( "preds, target, ref_metric, num_outputs", [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_metric, 1), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_metric, num_targets), + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper, 1), + (_multi_target_inputs.preds, 
_multi_target_inputs.target, _multi_target_ref_wrapper, NUM_TARGETS), ], ) class TestR2Score(MetricTester): diff --git a/tests/unittests/regression/test_rse.py b/tests/unittests/regression/test_rse.py index 886580c84bf..0f127ed796c 100644 --- a/tests/unittests/regression/test_rse.py +++ b/tests/unittests/regression/test_rse.py @@ -26,7 +26,7 @@ seed_all(42) -num_targets = 5 +NUM_TARGETS = 5 _single_target_inputs = _Input( @@ -35,12 +35,12 @@ ) _multi_target_inputs = _Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_TARGETS), ) -def _sk_rse(target, preds, squared): +def _reference_rse(target, preds, squared): mean = np.mean(target, axis=0, keepdims=True) error = target - preds sum_squared_error = np.sum(error * error, axis=0) @@ -52,24 +52,24 @@ def _sk_rse(target, preds, squared): return np.mean(rse) -def _single_target_ref_metric(preds, target, squared): +def _single_target_ref_wrapper(preds, target, squared): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - return _sk_rse(sk_target, sk_preds, squared=squared) + return _reference_rse(sk_target, sk_preds, squared=squared) -def _multi_target_ref_metric(preds, target, squared): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return _sk_rse(sk_target, sk_preds, squared=squared) +def _multi_target_ref_wrapper(preds, target, squared): + sk_preds = preds.view(-1, NUM_TARGETS).numpy() + sk_target = target.view(-1, NUM_TARGETS).numpy() + return _reference_rse(sk_target, sk_preds, squared=squared) @pytest.mark.parametrize("squared", [False, True]) @pytest.mark.parametrize( "preds, target, ref_metric, num_outputs", [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_metric, 1), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_metric, num_targets), + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper, 1), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper, NUM_TARGETS), ], ) class TestRelativeSquaredError(MetricTester): diff --git a/tests/unittests/regression/test_spearman.py b/tests/unittests/regression/test_spearman.py index af65ed8604f..4f3bd2ab790 100644 --- a/tests/unittests/regression/test_spearman.py +++ b/tests/unittests/regression/test_spearman.py @@ -70,7 +70,7 @@ def test_ranking(preds, target): assert (torch.tensor(scipy_ranking[1]) == tm_ranking[1]).all() -def _scipy_spearman(preds, target): +def _reference_scipy_spearman(preds, target): if preds.ndim == 2: return [spearmanr(t.numpy(), p.numpy())[0] for t, p in zip(target.T, preds.T)] return spearmanr(target.numpy(), preds.numpy())[0] @@ -100,13 +100,13 @@ def test_spearman_corrcoef(self, preds, target, ddp): preds, target, SpearmanCorrCoef, - _scipy_spearman, + _reference_scipy_spearman, metric_args={"num_outputs": num_outputs}, ) def test_spearman_corrcoef_functional(self, preds, target): """Test functional implementation of metric.""" - self.run_functional_metric_test(preds, target, spearman_corrcoef, _scipy_spearman) + self.run_functional_metric_test(preds, target, spearman_corrcoef, _reference_scipy_spearman) def test_spearman_corrcoef_differentiability(self, preds, target): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" diff --git 
a/tests/unittests/regression/test_tweedie_deviance.py b/tests/unittests/regression/test_tweedie_deviance.py index ce204a0e5c4..00324dc041f 100644 --- a/tests/unittests/regression/test_tweedie_deviance.py +++ b/tests/unittests/regression/test_tweedie_deviance.py @@ -43,7 +43,7 @@ ) -def _sklearn_deviance(preds: Tensor, targets: Tensor, power: float): +def _reference_sklearn_deviance(preds: Tensor, targets: Tensor, power: float): sk_preds = preds.view(-1).numpy() sk_target = targets.view(-1).numpy() return mean_tweedie_deviance(sk_target, sk_preds, power=power) @@ -69,7 +69,7 @@ def test_deviance_scores_class(self, ddp, preds, target, power): preds, target, TweedieDevianceScore, - partial(_sklearn_deviance, power=power), + partial(_reference_sklearn_deviance, power=power), metric_args={"power": power}, ) @@ -79,7 +79,7 @@ def test_deviance_scores_functional(self, preds, target, power): preds, target, tweedie_deviance_score, - partial(_sklearn_deviance, power=power), + partial(_reference_sklearn_deviance, power=power), metric_args={"power": power}, ) diff --git a/tests/unittests/retrieval/inputs.py b/tests/unittests/retrieval/_inputs.py similarity index 100% rename from tests/unittests/retrieval/inputs.py rename to tests/unittests/retrieval/_inputs.py diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index de93fa24ad0..748c8f993b0 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -24,18 +24,18 @@ from unittests.helpers import seed_all from unittests.helpers.testers import Metric, MetricTester -from unittests.retrieval.inputs import _input_retrieval_scores as _irs -from unittests.retrieval.inputs import _input_retrieval_scores_all_target as _irs_all -from unittests.retrieval.inputs import _input_retrieval_scores_empty as _irs_empty -from unittests.retrieval.inputs import _input_retrieval_scores_extra as _irs_extra -from unittests.retrieval.inputs import _input_retrieval_scores_float_target as _irs_float_tgt -from unittests.retrieval.inputs import _input_retrieval_scores_for_adaptive_k as _irs_adpt_k -from unittests.retrieval.inputs import _input_retrieval_scores_int_target as _irs_int_tgt -from unittests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes as _irs_bad_sz -from unittests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_bad_sz_fn -from unittests.retrieval.inputs import _input_retrieval_scores_no_target as _irs_no_tgt -from unittests.retrieval.inputs import _input_retrieval_scores_with_ignore_index as _irs_ii -from unittests.retrieval.inputs import _input_retrieval_scores_wrong_targets as _irs_bad_tgt +from unittests.retrieval._inputs import _input_retrieval_scores as _irs +from unittests.retrieval._inputs import _input_retrieval_scores_all_target as _irs_all +from unittests.retrieval._inputs import _input_retrieval_scores_empty as _irs_empty +from unittests.retrieval._inputs import _input_retrieval_scores_extra as _irs_extra +from unittests.retrieval._inputs import _input_retrieval_scores_float_target as _irs_float_tgt +from unittests.retrieval._inputs import _input_retrieval_scores_for_adaptive_k as _irs_adpt_k +from unittests.retrieval._inputs import _input_retrieval_scores_int_target as _irs_int_tgt +from unittests.retrieval._inputs import _input_retrieval_scores_mismatching_sizes as _irs_bad_sz +from unittests.retrieval._inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_bad_sz_fn +from unittests.retrieval._inputs import 
_input_retrieval_scores_no_target as _irs_no_tgt +from unittests.retrieval._inputs import _input_retrieval_scores_with_ignore_index as _irs_ii +from unittests.retrieval._inputs import _input_retrieval_scores_wrong_targets as _irs_bad_tgt seed_all(42) diff --git a/tests/unittests/text/inputs.py b/tests/unittests/text/_inputs.py similarity index 100% rename from tests/unittests/text/inputs.py rename to tests/unittests/text/_inputs.py diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index bffd8735b25..3b2382a8488 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -23,8 +23,8 @@ from typing_extensions import Literal from unittests.helpers import skip_on_connection_issues +from unittests.text._inputs import _inputs_single_reference from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_single_reference if _BERTSCORE_AVAILABLE: from bert_score import score as original_bert_score @@ -43,9 +43,9 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" +@skip_on_connection_issues() @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -@skip_on_connection_issues() def _reference_bert_score( preds: Sequence[str], target: Sequence[str], diff --git a/tests/unittests/text/test_bleu.py b/tests/unittests/text/test_bleu.py index 53b3ad43d34..3c271cf10e6 100644 --- a/tests/unittests/text/test_bleu.py +++ b/tests/unittests/text/test_bleu.py @@ -20,8 +20,8 @@ from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.text.bleu import BLEUScore +from unittests.text._inputs import _inputs_multiple_references from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_multiple_references # https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction smooth_func = SmoothingFunction().method2 diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index bc9f061079c..34c7a9735a9 100644 --- a/tests/unittests/text/test_cer.py +++ b/tests/unittests/text/test_cer.py @@ -18,8 +18,8 @@ from torchmetrics.text.cer import CharErrorRate from torchmetrics.utilities.imports import _JIWER_AVAILABLE +from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 if _JIWER_AVAILABLE: from jiwer import cer diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py index ba409b8bada..4df8e5d8c22 100644 --- a/tests/unittests/text/test_chrf.py +++ b/tests/unittests/text/test_chrf.py @@ -20,8 +20,8 @@ from torchmetrics.text.chrf import CHRFScore from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE +from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references if _SACREBLEU_AVAILABLE: from sacrebleu.metrics import CHRF diff --git a/tests/unittests/text/test_edit.py b/tests/unittests/text/test_edit.py index f3b012e9dd8..a7d4029cef6 100644 --- a/tests/unittests/text/test_edit.py +++ b/tests/unittests/text/test_edit.py @@ -18,8 +18,8 @@ from 
torchmetrics.functional.text.edit import edit_distance from torchmetrics.text.edit import EditDistance +from unittests.text._inputs import _inputs_single_reference from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_single_reference @pytest.mark.parametrize( diff --git a/tests/unittests/text/test_eed.py b/tests/unittests/text/test_eed.py index a49de9fe507..964df16d3d1 100644 --- a/tests/unittests/text/test_eed.py +++ b/tests/unittests/text/test_eed.py @@ -19,8 +19,8 @@ from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.text.eed import ExtendedEditDistance +from unittests.text._inputs import _inputs_single_reference, _inputs_single_sentence_multiple_references from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_single_reference, _inputs_single_sentence_multiple_references def _reference_rwth_manual(preds, targets) -> Tensor: diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index 59c8979e812..1ddd6d9bcfe 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -20,8 +20,8 @@ from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 from unittests.helpers import skip_on_connection_issues +from unittests.text._inputs import HYPOTHESIS_A, HYPOTHESIS_C, _inputs_single_reference from unittests.text.helpers import TextTester -from unittests.text.inputs import HYPOTHESIS_A, HYPOTHESIS_C, _inputs_single_reference # Small bert model with 2 layers, 2 attention heads and hidden dim of 128 MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2" diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index f53500df937..9ff0823f173 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -19,8 +19,8 @@ from torchmetrics.utilities.imports import _JIWER_AVAILABLE from unittests.helpers import seed_all +from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 if _JIWER_AVAILABLE: from jiwer import compute_measures diff --git a/tests/unittests/text/test_perplexity.py b/tests/unittests/text/test_perplexity.py index 6cd49320e23..3673da47647 100644 --- a/tests/unittests/text/test_perplexity.py +++ b/tests/unittests/text/test_perplexity.py @@ -21,7 +21,7 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from unittests.helpers.testers import MetricTester -from unittests.text.inputs import ( +from unittests.text._inputs import ( MASK_INDEX, _logits_inputs_fp32, _logits_inputs_fp32_with_mask, diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index ad76e8e87ee..8d59929c488 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -25,8 +25,8 @@ from typing_extensions import Literal from unittests.helpers import skip_on_connection_issues +from unittests.text._inputs import _Input, _inputs_multiple_references, _inputs_single_sentence_single_reference from unittests.text.helpers import TextTester -from unittests.text.inputs import _Input, _inputs_multiple_references, _inputs_single_sentence_single_reference if _ROUGE_SCORE_AVAILABLE: from rouge_score.rouge_scorer import RougeScorer diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index 
7d51e1274dd..362fcacf59e 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -21,14 +21,16 @@ from torchmetrics.text.sacre_bleu import SacreBLEUScore from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE +from unittests.text._inputs import _inputs_multiple_references from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_multiple_references if _SACREBLEU_AVAILABLE: from sacrebleu.metrics import BLEU -def _sacrebleu_fn(preds: Sequence[str], targets: Sequence[Sequence[str]], tokenize: str, lowercase: bool) -> Tensor: +def _reference_sacre_bleu( + preds: Sequence[str], targets: Sequence[Sequence[str]], tokenize: str, lowercase: bool +) -> Tensor: sacrebleu_fn = BLEU(tokenize=tokenize, lowercase=lowercase) # Sacrebleu expects different format of input targets = [[target[i] for target in targets] for i in range(len(targets[0]))] @@ -50,7 +52,7 @@ class TestSacreBLEUScore(TextTester): def test_bleu_score_class(self, ddp, preds, targets, tokenize, lowercase): """Test class implementation of metric.""" metric_args = {"tokenize": tokenize, "lowercase": lowercase} - original_sacrebleu = partial(_sacrebleu_fn, tokenize=tokenize, lowercase=lowercase) + original_sacrebleu = partial(_reference_sacre_bleu, tokenize=tokenize, lowercase=lowercase) self.run_class_metric_test( ddp=ddp, @@ -64,7 +66,7 @@ def test_bleu_score_class(self, ddp, preds, targets, tokenize, lowercase): def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): """Test functional implementation of metric.""" metric_args = {"tokenize": tokenize, "lowercase": lowercase} - original_sacrebleu = partial(_sacrebleu_fn, tokenize=tokenize, lowercase=lowercase) + original_sacrebleu = partial(_reference_sacre_bleu, tokenize=tokenize, lowercase=lowercase) self.run_functional_metric_test( preds, @@ -114,7 +116,7 @@ def test_tokenize_ja_mecab(): preds = ["これは美しい花です。"] targets = [["これは美しい花です。", "おいしい寿司を食べたい。"]] - assert sacrebleu(preds, targets) == _sacrebleu_fn(preds, targets, tokenize="ja-mecab", lowercase=False) + assert sacrebleu(preds, targets) == _reference_sacre_bleu(preds, targets, tokenize="ja-mecab", lowercase=False) def test_tokenize_ko_mecab(): @@ -123,7 +125,7 @@ def test_tokenize_ko_mecab(): preds = ["이 책은 정말 재미있어요."] targets = [["이 책은 정말 재미있어요.", "고마워요, 너무 도와줘서."]] - assert sacrebleu(preds, targets) == _sacrebleu_fn(preds, targets, tokenize="ko-mecab", lowercase=False) + assert sacrebleu(preds, targets) == _reference_sacre_bleu(preds, targets, tokenize="ko-mecab", lowercase=False) def test_equivalence_of_available_tokenizers_and_annotation(): diff --git a/tests/unittests/text/test_squad.py b/tests/unittests/text/test_squad.py index 51a2e24e38f..5bb4cc0c7fa 100644 --- a/tests/unittests/text/test_squad.py +++ b/tests/unittests/text/test_squad.py @@ -21,7 +21,7 @@ from torchmetrics.text.squad import SQuAD from unittests.helpers.testers import _assert_allclose, _assert_tensor -from unittests.text.inputs import _inputs_squad_batch_match, _inputs_squad_exact_match, _inputs_squad_exact_mismatch +from unittests.text._inputs import _inputs_squad_batch_match, _inputs_squad_exact_match, _inputs_squad_exact_mismatch @pytest.mark.parametrize( diff --git a/tests/unittests/text/test_ter.py b/tests/unittests/text/test_ter.py index 10743397fe2..f6dd90f2c36 100644 --- a/tests/unittests/text/test_ter.py +++ b/tests/unittests/text/test_ter.py @@ -20,8 +20,8 @@ from torchmetrics.text.ter import TranslationEditRate from 
torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE +from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references if _SACREBLEU_AVAILABLE: from sacrebleu.metrics import TER as SacreTER # noqa: N811 diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index ab143a8372c..bb781f8f0ac 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -18,8 +18,8 @@ from torchmetrics.text.wer import WordErrorRate from torchmetrics.utilities.imports import _JIWER_AVAILABLE +from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 if _JIWER_AVAILABLE: from jiwer import compute_measures @@ -27,7 +27,7 @@ compute_measures: Callable -def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]): +def _reference_jiwer_wer(preds: Union[str, List[str]], target: Union[str, List[str]]): return compute_measures(target, preds)["wer"] @@ -50,7 +50,7 @@ def test_wer_class(self, ddp, preds, targets): preds=preds, targets=targets, metric_class=WordErrorRate, - reference_metric=_compute_wer_metric_jiwer, + reference_metric=_reference_jiwer_wer, ) def test_wer_functional(self, preds, targets): @@ -59,7 +59,7 @@ def test_wer_functional(self, preds, targets): preds, targets, metric_functional=word_error_rate, - reference_metric=_compute_wer_metric_jiwer, + reference_metric=_reference_jiwer_wer, ) def test_wer_differentiability(self, preds, targets): diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index dc686c1e979..9b88866071a 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -19,8 +19,8 @@ from torchmetrics.text.wil import WordInfoLost from torchmetrics.utilities.imports import _JIWER_AVAILABLE +from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 def _reference_jiwer_wil(preds: Union[str, List[str]], target: Union[str, List[str]]): diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index 01620d9376b..c6ce8a89a8e 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -19,8 +19,8 @@ from torchmetrics.text.wip import WordInfoPreserved from torchmetrics.utilities.imports import _JIWER_AVAILABLE +from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester -from unittests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 def _reference_jiwer_wip(preds: Union[str, List[str]], target: Union[str, List[str]]): From a8c11b4c7c1a19796370c97acf1d9e4817be5eb2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 28 Feb 2024 10:24:49 +0100 Subject: [PATCH 04/10] tests: prefer cache for missing config (#2414) move import inside ref metric --- .github/actions/pull-caches/action.yml | 2 +- .github/actions/push-caches/action.yml | 2 +- .gitignore | 2 +- 
src/torchmetrics/utilities/imports.py | 4 ---- .../classification/test_group_fairness.py | 2 -- tests/unittests/image/test_fid.py | 8 ++++---- tests/unittests/image/test_inception.py | 8 ++++---- tests/unittests/image/test_kid.py | 10 +++++----- tests/unittests/image/test_lpips.py | 13 +++++++------ tests/unittests/image/test_mifid.py | 6 +++--- .../image/test_perceptual_path_length.py | 8 ++++---- tests/unittests/multimodal/test_clip_iqa.py | 10 +++++----- tests/unittests/nominal/test_cramers.py | 12 ++++++------ tests/unittests/nominal/test_pearson.py | 9 ++++----- tests/unittests/nominal/test_theils_u.py | 10 +++++----- tests/unittests/nominal/test_tschuprows.py | 9 ++++----- tests/unittests/text/test_bertscore.py | 14 ++++++-------- tests/unittests/text/test_cer.py | 15 ++++++--------- tests/unittests/text/test_chrf.py | 10 +++++----- tests/unittests/text/test_infolm.py | 6 +++--- tests/unittests/text/test_mer.py | 13 +++++-------- tests/unittests/text/test_rouge.py | 2 +- tests/unittests/text/test_sacre_bleu.py | 10 +++++----- tests/unittests/text/test_ter.py | 12 ++++++------ tests/unittests/text/test_wer.py | 14 ++++++-------- tests/unittests/text/test_wil.py | 8 +++++--- tests/unittests/text/test_wip.py | 8 +++++--- 27 files changed, 107 insertions(+), 120 deletions(-) diff --git a/.github/actions/pull-caches/action.yml b/.github/actions/pull-caches/action.yml index a5cf7cafe2d..fec93a58bde 100644 --- a/.github/actions/pull-caches/action.yml +++ b/.github/actions/pull-caches/action.yml @@ -90,5 +90,5 @@ runs: - name: Restored References continue-on-error: true - run: ls -lh tests/_cache-references/ + run: py-tree tests/_cache-references/ --show_hidden shell: bash diff --git a/.github/actions/push-caches/action.yml b/.github/actions/push-caches/action.yml index da757f09b5d..8f5db36b6dd 100644 --- a/.github/actions/push-caches/action.yml +++ b/.github/actions/push-caches/action.yml @@ -99,5 +99,5 @@ runs: key: cache-references - name: Post References - run: ls -lh tests/_cache-references/ + run: py-tree tests/_cache-references/ --show_hidden shell: bash diff --git a/.gitignore b/.gitignore index 6f45b493e3c..cbe31a9b316 100644 --- a/.gitignore +++ b/.gitignore @@ -40,7 +40,7 @@ pip-delete-this-directory.txt # Unit test / coverage reports tests/_data/ data.zip -tests/_reference-cache/ +tests/_cache-references/ htmlcov/ .coverage .coverage.* diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 085269dd4d9..6e80411f5c1 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -17,10 +17,8 @@ import sys from lightning_utilities.core.imports import RequirementCache -from packaging.version import Version, parse _PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" -_PYTHON_LOWER_3_8 = parse(_PYTHON_VERSION) < Version("3.8") _TORCH_LOWER_2_0 = RequirementCache("torch<2.0.0") _TORCH_GREATER_EQUAL_1_11 = RequirementCache("torch>=1.11.0") _TORCH_GREATER_EQUAL_1_12 = RequirementCache("torch>=1.12.0") @@ -29,7 +27,6 @@ _TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0") _TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0") -_JIWER_AVAILABLE = RequirementCache("jiwer") _NLTK_AVAILABLE = RequirementCache("nltk") _ROUGE_SCORE_AVAILABLE = RequirementCache("rouge_score") _BERTSCORE_AVAILABLE = RequirementCache("bert_score") @@ -49,7 +46,6 @@ _GAMMATONE_AVAILABLE = RequirementCache("gammatone") _TORCHAUDIO_AVAILABLE = RequirementCache("torchaudio") 
_TORCHAUDIO_GREATER_EQUAL_0_10 = RequirementCache("torchaudio>=0.10.0") -_SACREBLEU_AVAILABLE = RequirementCache("sacrebleu") _REGEX_AVAILABLE = RequirementCache("regex") _PYSTOI_AVAILABLE = RequirementCache("pystoi") _FAST_BSS_EVAL_AVAILABLE = RequirementCache("fast_bss_eval") diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index 811e2a55ab9..4d76b9301dd 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -26,7 +26,6 @@ from torchmetrics import Metric from torchmetrics.classification.group_fairness import BinaryFairness from torchmetrics.functional.classification.group_fairness import binary_fairness -from torchmetrics.utilities.imports import _PYTHON_LOWER_3_8 from unittests import THRESHOLD from unittests.classification._inputs import _group_cases @@ -222,7 +221,6 @@ def run_precision_test_gpu( @mock.patch("unittests.helpers.testers._assert_tensor", _assert_tensor) @mock.patch("unittests.helpers.testers._assert_allclose", _assert_allclose) -@pytest.mark.skipif(_PYTHON_LOWER_3_8, reason="`TestBinaryFairness` requires `python>=3.8`.") @pytest.mark.parametrize("inputs", _group_cases) class TestBinaryFairness(BinaryFairnessTester): """Test class for `BinaryFairness` metric.""" diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index 83b243200da..252f0d0ebba 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -34,7 +34,7 @@ def test_no_train_network_missing_torch_fidelity(): NoTrainInceptionV3(name="inception-v3-compat", features_list=["2048"]) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_no_train(): """Assert that metric never leaves evaluation mode.""" @@ -52,7 +52,7 @@ def forward(self, x): assert not model.metric.inception.training, "FID metric was changed to training mode which should not happen" -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_fid_pickle(): """Assert that we can initialize the metric and pickle it.""" metric = FrechetInceptionDistance() @@ -80,7 +80,7 @@ def test_fid_raises_errors_and_warnings(): _ = FrechetInceptionDistance(feature=[1, 2]) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.parametrize("feature", [64, 192, 768, 2048]) def test_fid_same_input(feature): """If real and fake are update on the same data the fid score should be 0.""" @@ -111,7 +111,7 @@ def __len__(self) -> int: @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu") -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.parametrize("equal_size", [False, True]) def test_compare_fid(tmpdir, equal_size, feature=768): """Check that the hole pipeline give the same result as torch-fidelity.""" diff --git a/tests/unittests/image/test_inception.py b/tests/unittests/image/test_inception.py index 552180cbbcc..627e6a4a57a 100644 --- a/tests/unittests/image/test_inception.py +++ 
b/tests/unittests/image/test_inception.py @@ -24,7 +24,7 @@ torch.manual_seed(42) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_no_train(): """Assert that metric never leaves evaluation mode.""" @@ -44,7 +44,7 @@ def forward(self, x): ), "InceptionScore metric was changed to training mode which should not happen" -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_is_pickle(): """Assert that we can initialize the metric and pickle it.""" metric = InceptionScore() @@ -79,7 +79,7 @@ def test_is_raises_errors_and_warnings(): InceptionScore(feature=[1, 2]) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_is_update_compute(): """Test that inception score works as expected.""" metric = InceptionScore() @@ -105,7 +105,7 @@ def __len__(self) -> int: @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu") -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.parametrize("compute_on_cpu", [True, False]) def test_compare_is(tmpdir, compute_on_cpu): """Check that the hole pipeline give the same result as torch-fidelity.""" diff --git a/tests/unittests/image/test_kid.py b/tests/unittests/image/test_kid.py index 34d223e24af..a754768003c 100644 --- a/tests/unittests/image/test_kid.py +++ b/tests/unittests/image/test_kid.py @@ -24,7 +24,7 @@ torch.manual_seed(42) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_no_train(): """Assert that metric never leaves evaluation mode.""" @@ -42,7 +42,7 @@ def forward(self, x): assert not model.metric.inception.training, "FID metric was changed to training mode which should not happen" -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_kid_pickle(): """Assert that we can initialize the metric and pickle it.""" metric = KernelInceptionDistance() @@ -83,7 +83,7 @@ def test_kid_raises_errors_and_warnings(): m.compute() -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_kid_extra_parameters(): """Test that the different input arguments raises expected errors if wrong.""" with pytest.raises(ValueError, match="Argument `subsets` expected to be integer larger than 0"): @@ -102,7 +102,7 @@ def test_kid_extra_parameters(): KernelInceptionDistance(coef=-1) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.parametrize("feature", [64, 192, 768, 2048]) def test_kid_same_input(feature): """Test that the metric works.""" @@ -132,7 +132,7 @@ def __len__(self) -> int: @pytest.mark.skipif(not torch.cuda.is_available(), 
reason="test is too slow without gpu") -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_compare_kid(tmpdir, feature=2048): """Check that the hole pipeline give the same result as torch-fidelity.""" from torch_fidelity import calculate_metrics diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index 0a7171ab996..026c2b91770 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -16,11 +16,10 @@ import pytest import torch -from lpips import LPIPS as LPIPS_reference # noqa: N811 from torch import Tensor from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity -from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCHVISION_AVAILABLE +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -43,7 +42,12 @@ def _reference_lpips( img1: Tensor, img2: Tensor, net_type: str, normalize: bool = False, reduction: str = "mean" ) -> Tensor: """Comparison function for tm implementation.""" - ref = LPIPS_reference(net=net_type) + try: + from lpips import LPIPS + except ImportError: + pytest.skip("test requires lpips package to be installed") + + ref = LPIPS(net=net_type) res = ref(img1, img2, normalize=normalize).detach().cpu().numpy() if reduction == "mean": return res.mean() @@ -51,7 +55,6 @@ def _reference_lpips( @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed") -@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed") class TestLPIPS(MetricTester): """Test class for `LearnedPerceptualImagePatchSimilarity` metric.""" @@ -109,7 +112,6 @@ def test_normalize_arg(normalize): @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed") -@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed") def test_error_on_wrong_init(): """Test class raises the expected errors.""" with pytest.raises(ValueError, match="Argument `net_type` must be one .*"): @@ -120,7 +122,6 @@ def test_error_on_wrong_init(): @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed") -@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed") @pytest.mark.parametrize( ("inp1", "inp2"), [ diff --git a/tests/unittests/image/test_mifid.py b/tests/unittests/image/test_mifid.py index d5bdb95cf68..ae44982b350 100644 --- a/tests/unittests/image/test_mifid.py +++ b/tests/unittests/image/test_mifid.py @@ -98,7 +98,7 @@ def calculate_mifid(m1, s1, features1, m2, s2, features2): return fid_private / (distance_private_thresholded + 1e-15) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") def test_no_train(): """Assert that metric never leaves evaluation mode.""" @@ -139,7 +139,7 @@ def test_mifid_raises_errors_and_warnings(): _ = MemorizationInformedFrechetInceptionDistance(cosine_distance_eps=1.1) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires 
torch-fidelity") @pytest.mark.parametrize("feature", [64, 192, 768, 2048]) def test_fid_same_input(feature): """If real and fake are update on the same data the fid score should be 0.""" @@ -157,7 +157,7 @@ def test_fid_same_input(feature): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu") -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.parametrize("equal_size", [False, True]) def test_compare_mifid(equal_size): """Check that our implementation of MIFID is correct by comparing it to the original implementation.""" diff --git a/tests/unittests/image/test_perceptual_path_length.py b/tests/unittests/image/test_perceptual_path_length.py index dfdd5cfde96..1eb486c6ce3 100644 --- a/tests/unittests/image/test_perceptual_path_length.py +++ b/tests/unittests/image/test_perceptual_path_length.py @@ -29,7 +29,7 @@ seed_all(42) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.parametrize("interpolation_method", ["lerp", "slerp_any", "slerp_unit"]) def test_interpolation_methods(interpolation_method): """Test that interpolation method works as expected.""" @@ -41,7 +41,7 @@ def test_interpolation_methods(interpolation_method): assert torch.allclose(res1, res2) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @skip_on_running_out_of_memory() def test_sim_net(): """Check that the similarity network is the same as the one used in torch_fidelity.""" @@ -100,7 +100,7 @@ def sample(self, num_samples): return torch.randn(num_samples, self.z_size) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.parametrize( ("argument", "match"), [ @@ -174,7 +174,7 @@ def test_raises_error_on_wrong_generator(generator, errortype, match): ppl.update(generator=generator) -@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity") +@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @skip_on_running_out_of_memory() def test_compare(): diff --git a/tests/unittests/multimodal/test_clip_iqa.py b/tests/unittests/multimodal/test_clip_iqa.py index c7057226759..314ff0013b8 100644 --- a/tests/unittests/multimodal/test_clip_iqa.py +++ b/tests/unittests/multimodal/test_clip_iqa.py @@ -71,7 +71,7 @@ def _reference_clip_iqa(preds, target, reduce=False): return res.sum() if reduce else res -@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8") @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") class TestCLIPIQA(MetricTester): """Test clip iqa metric.""" @@ -104,7 +104,7 @@ def test_clip_iqa_functional(self, shapes): @skip_on_connection_issues() -@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8") 
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") @pytest.mark.skipif(not os.path.isfile(_SAMPLE_IMAGE), reason="test image not found") def test_for_correctness_sample_images(): @@ -121,7 +121,7 @@ def test_for_correctness_sample_images(): @skip_on_connection_issues() -@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8") @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") @pytest.mark.parametrize( "model", @@ -148,7 +148,7 @@ def test_other_models(model): @skip_on_connection_issues() -@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8") @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") @pytest.mark.parametrize( "prompts", @@ -200,7 +200,7 @@ def test_prompt(prompts): @skip_on_connection_issues() -@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8") @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") def test_plot_method(): """Test the plot method of CLIPScore separately in this file due to the skipping conditions.""" diff --git a/tests/unittests/nominal/test_cramers.py b/tests/unittests/nominal/test_cramers.py index 4cebac73e05..42b735ef510 100644 --- a/tests/unittests/nominal/test_cramers.py +++ b/tests/unittests/nominal/test_cramers.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import itertools -import operator from functools import partial import pytest import torch -from dython.nominal import cramers_v as dython_cramers_v -from lightning_utilities.core.imports import compare_version from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix from torchmetrics.nominal.cramers import CramersV @@ -63,10 +60,15 @@ def cramers_matrix_input(): def _reference_dython_cramers_v(preds, target, bias_correction, nan_strategy, nan_replace_value): + try: + from dython.nominal import cramers_v + except ImportError: + pytest.skip("This test requires `dython` package to be installed.") + preds = preds.argmax(1) if preds.ndim == 2 else preds target = target.argmax(1) if target.ndim == 2 else target - v = dython_cramers_v( + v = cramers_v( preds.numpy(), target.numpy(), bias_correction=bias_correction, @@ -87,7 +89,6 @@ def _dython_cramers_v_matrix(matrix, bias_correction, nan_strategy, nan_replace_ return cramers_v_matrix_value -@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`") @pytest.mark.parametrize( "preds, target", [ @@ -161,7 +162,6 @@ def test_cramers_v_differentiability(self, preds, target, bias_correction, nan_s ) -@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`") @pytest.mark.parametrize("bias_correction", [False, True]) @pytest.mark.parametrize(("nan_strategy", "nan_replace_value"), [("replace", 1.0), ("drop", None)]) def test_cramers_v_matrix(cramers_matrix_input, bias_correction, nan_strategy, nan_replace_value): diff --git a/tests/unittests/nominal/test_pearson.py b/tests/unittests/nominal/test_pearson.py index 44bf1c0e415..5bec1cd8121 100644 --- a/tests/unittests/nominal/test_pearson.py +++ b/tests/unittests/nominal/test_pearson.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import itertools -import operator import pandas as pd import pytest import torch -from lightning_utilities.core.imports import compare_version -from scipy.stats.contingency import association from torchmetrics.functional.nominal.pearson import ( pearsons_contingency_coefficient, pearsons_contingency_coefficient_matrix, @@ -56,6 +53,10 @@ def pearson_matrix_input(): def _reference_pd_pearsons_t(preds, target): + try: + from scipy.stats.contingency import association + except ImportError: + pytest.skip("test requires scipy package to be installed") preds = preds.argmax(1) if preds.ndim == 2 else preds target = target.argmax(1) if target.ndim == 2 else target preds, target = preds.numpy().astype(int), target.numpy().astype(int) @@ -74,7 +75,6 @@ def _reference_pd_pearsons_t_matrix(matrix): return pearsons_t_matrix_value -@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`") @pytest.mark.parametrize( "preds, target", [ @@ -118,7 +118,6 @@ def test_pearsons_t_differentiability(self, preds, target): ) -@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`") def test_pearsons_contingency_coefficient_matrix(pearson_matrix_input): """Test matrix version of metric works as expected.""" tm_score = pearsons_contingency_coefficient_matrix(pearson_matrix_input) diff --git a/tests/unittests/nominal/test_theils_u.py b/tests/unittests/nominal/test_theils_u.py index b7ae4b29507..c06c6b9bcd2 100644 --- a/tests/unittests/nominal/test_theils_u.py +++ b/tests/unittests/nominal/test_theils_u.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -import operator from functools import partial import pytest import torch -from dython.nominal import theils_u as dython_theils_u -from lightning_utilities.core.imports import compare_version from torchmetrics.functional.nominal.theils_u import theils_u, theils_u_matrix from torchmetrics.nominal import TheilsU @@ -63,6 +60,11 @@ def theils_u_matrix_input(): def _reference_dython_theils_u(preds, target, nan_strategy, nan_replace_value): + try: + from dython.nominal import theils_u as dython_theils_u + except ImportError: + pytest.skip("Test requires `dython` package to be installed.") + preds = preds.argmax(1) if preds.ndim == 2 else preds target = target.argmax(1) if target.ndim == 2 else target @@ -85,7 +87,6 @@ def _reference_dython_theils_u_matrix(matrix, nan_strategy, nan_replace_value): return theils_u_matrix_value -@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`") @pytest.mark.parametrize( "preds, target", [ @@ -153,7 +154,6 @@ def test_theils_u_differentiability(self, preds, target, nan_strategy, nan_repla ) -@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`") @pytest.mark.parametrize(("nan_strategy", "nan_replace_value"), [("replace", 1.0), ("drop", None)]) def test_theils_u_matrix(theils_u_matrix_input, nan_strategy, nan_replace_value): """Test matrix version of metric works as expected.""" diff --git a/tests/unittests/nominal/test_tschuprows.py b/tests/unittests/nominal/test_tschuprows.py index 48102ac6f34..91798d88d82 100644 --- a/tests/unittests/nominal/test_tschuprows.py +++ b/tests/unittests/nominal/test_tschuprows.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions 
and # limitations under the License. import itertools -import operator import pandas as pd import pytest import torch -from lightning_utilities.core.imports import compare_version -from scipy.stats.contingency import association from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix from torchmetrics.nominal.tschuprows import TschuprowsT @@ -53,6 +50,10 @@ def tschuprows_matrix_input(): def _reference_pd_tschuprows_t(preds, target): + try: + from scipy.stats.contingency import association + except ImportError: + pytest.skip("test requires scipy package to be installed") preds = preds.argmax(1) if preds.ndim == 2 else preds target = target.argmax(1) if target.ndim == 2 else target preds, target = preds.numpy().astype(int), target.numpy().astype(int) @@ -71,7 +72,6 @@ def _reference_pd_tschuprows_t_matrix(matrix): return tschuprows_t_matrix_value -@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`") @pytest.mark.parametrize( "preds, target", [ @@ -120,7 +120,6 @@ def test_tschuprows_t_differentiability(self, preds, target): ) -@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`") def test_tschuprows_t_matrix(tschuprows_matrix_input): """Test matrix version of metric works as expected.""" tm_score = tschuprows_t_matrix(tschuprows_matrix_input, bias_correction=False) diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index 3b2382a8488..b651ecfddb1 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -19,18 +19,13 @@ from torch import Tensor from torchmetrics.functional.text.bert import bert_score from torchmetrics.text.bert import BERTScore -from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE, _TRANSFORMERS_GREATER_EQUAL_4_4 +from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 from typing_extensions import Literal from unittests.helpers import skip_on_connection_issues from unittests.text._inputs import _inputs_single_reference from unittests.text.helpers import TextTester -if _BERTSCORE_AVAILABLE: - from bert_score import score as original_bert_score -else: - original_bert_score = None - _METRIC_KEY_TO_IDX = { "precision": 0, "recall": 1, @@ -45,7 +40,6 @@ @skip_on_connection_issues() @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") -@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") def _reference_bert_score( preds: Sequence[str], target: Sequence[str], @@ -55,6 +49,11 @@ def _reference_bert_score( rescale_with_baseline: bool, metric_key: Literal["f1", "precision", "recall"], ) -> Tensor: + try: + from bert_score import score as original_bert_score + except ImportError: + pytest.skip("test requires bert_score package to be installed.") + score_tuple = original_bert_score( preds, target, @@ -88,7 +87,6 @@ def _reference_bert_score( [(_inputs_single_reference.preds, _inputs_single_reference.target)], ) @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") -@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") class TestBERTScore(TextTester): """Tests for BERTScore.""" diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index 34c7a9735a9..ab09f3e5334 100644 --- a/tests/unittests/text/test_cer.py +++ 
b/tests/unittests/text/test_cer.py @@ -11,28 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Union +from typing import List, Union import pytest from torchmetrics.functional.text.cer import char_error_rate from torchmetrics.text.cer import CharErrorRate -from torchmetrics.utilities.imports import _JIWER_AVAILABLE from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester -if _JIWER_AVAILABLE: - from jiwer import cer - -else: - compute_measures = Callable - def _reference_jiwer_cer(preds: Union[str, List[str]], target: Union[str, List[str]]): + try: + from jiwer import cer + except ImportError: + pytest.skip("test requires jiwer package to be installed.") + return cer(target, preds) -@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer") @pytest.mark.parametrize( ["preds", "targets"], [ diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py index 4df8e5d8c22..6dd328d2c77 100644 --- a/tests/unittests/text/test_chrf.py +++ b/tests/unittests/text/test_chrf.py @@ -18,14 +18,10 @@ from torch import Tensor, tensor from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.text.chrf import CHRFScore -from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references from unittests.text.helpers import TextTester -if _SACREBLEU_AVAILABLE: - from sacrebleu.metrics import CHRF - def _reference_sacrebleu_chrf( preds: Sequence[str], @@ -35,6 +31,11 @@ def _reference_sacrebleu_chrf( lowercase: bool, whitespace: bool, ) -> Tensor: + try: + from sacrebleu import CHRF + except ImportError: + pytest.skip("test requires sacrebleu package to be installed") + sacrebleu_chrf = CHRF( char_order=char_order, word_order=word_order, lowercase=lowercase, whitespace=whitespace, eps_smoothing=True ) @@ -59,7 +60,6 @@ def _reference_sacrebleu_chrf( ["preds", "targets"], [(_inputs_multiple_references.preds, _inputs_multiple_references.target)], ) -@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestCHRFScore(TextTester): """Test class for `CHRFScore` metric.""" diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index 1ddd6d9bcfe..d8611695ff3 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -36,10 +36,10 @@ def _reference_infolm_score(preds, target, model_name, information_measure, idf, https://github.com/stancld/infolm-docker. """ - if model_name != "google/bert_uncased_L-2_H-128_A-2": + allowed_model = "google/bert_uncased_L-2_H-128_A-2" + if model_name != allowed_model: raise ValueError( - "`model_name` is expected to be 'google/bert_uncased_L-2_H-128_A-2' as this model was used for the result " - "generation." + f"`model_name` is expected to be '{allowed_model}' as this model was used for the result generation." 
) precomputed_result = { "kl_divergence": torch.tensor([-3.2250, -0.1784, -0.1784, -2.2182]), diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index 9ff0823f173..e6f5222c3b1 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -11,30 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Union +from typing import List, Union import pytest from torchmetrics.functional.text.mer import match_error_rate from torchmetrics.text.mer import MatchErrorRate -from torchmetrics.utilities.imports import _JIWER_AVAILABLE from unittests.helpers import seed_all from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester -if _JIWER_AVAILABLE: - from jiwer import compute_measures -else: - compute_measures: Callable - seed_all(42) def _reference_jiwer_mer(preds: Union[str, List[str]], target: Union[str, List[str]]): + try: + from jiwer import compute_measures + except ImportError: + pytest.skip("test requires jiwer package to be installed") return compute_measures(target, preds)["mer"] -@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer") @pytest.mark.parametrize( ["preds", "targets"], [ diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index 8d59929c488..c9eec8a055a 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -91,7 +91,7 @@ def _reference_rouge_score( return torch.tensor(rs_result, dtype=torch.float) -@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk") +@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="metric requires nltk") @pytest.mark.parametrize( ["pl_rouge_metric_key", "use_stemmer"], [ diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index 362fcacf59e..d74d032597d 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -19,18 +19,19 @@ from torch import Tensor, tensor from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _TokenizersLiteral, sacre_bleu_score from torchmetrics.text.sacre_bleu import SacreBLEUScore -from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE from unittests.text._inputs import _inputs_multiple_references from unittests.text.helpers import TextTester -if _SACREBLEU_AVAILABLE: - from sacrebleu.metrics import BLEU - def _reference_sacre_bleu( preds: Sequence[str], targets: Sequence[Sequence[str]], tokenize: str, lowercase: bool ) -> Tensor: + try: + from sacrebleu.metrics import BLEU + except ImportError: + pytest.skip("test requires sacrebleu package to be installed") + sacrebleu_fn = BLEU(tokenize=tokenize, lowercase=lowercase) # Sacrebleu expects different format of input targets = [[target[i] for target in targets] for i in range(len(targets[0]))] @@ -44,7 +45,6 @@ def _reference_sacre_bleu( ) @pytest.mark.parametrize(["lowercase"], [(False,), (True,)]) @pytest.mark.parametrize("tokenize", AVAILABLE_TOKENIZERS) -@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestSacreBLEUScore(TextTester): """Test class for `SacreBLEUScore` metric.""" diff --git a/tests/unittests/text/test_ter.py b/tests/unittests/text/test_ter.py index f6dd90f2c36..eb63451cf36 100644 --- 
a/tests/unittests/text/test_ter.py +++ b/tests/unittests/text/test_ter.py @@ -18,14 +18,10 @@ from torch import Tensor, tensor from torchmetrics.functional.text.ter import translation_edit_rate from torchmetrics.text.ter import TranslationEditRate -from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references from unittests.text.helpers import TextTester -if _SACREBLEU_AVAILABLE: - from sacrebleu.metrics import TER as SacreTER # noqa: N811 - def _reference_sacrebleu_ter( preds: Sequence[str], @@ -35,7 +31,12 @@ def _reference_sacrebleu_ter( asian_support: bool, case_sensitive: bool, ) -> Tensor: - sacrebleu_ter = SacreTER( + try: + from sacrebleu.metrics import TER + except ImportError: + pytest.skip("test requires sacrebleu package to be installed") + + sacrebleu_ter = TER( normalized=normalized, no_punct=no_punct, asian_support=asian_support, case_sensitive=case_sensitive ) # Sacrebleu CHRF expects different format of input @@ -59,7 +60,6 @@ def _reference_sacrebleu_ter( ["preds", "targets"], [(_inputs_multiple_references.preds, _inputs_multiple_references.target)], ) -@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestTER(TextTester): """Test class for `TranslationEditRate` metric.""" diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index bb781f8f0ac..6aee783d411 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -11,27 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Union +from typing import List, Union import pytest from torchmetrics.functional.text.wer import word_error_rate from torchmetrics.text.wer import WordErrorRate -from torchmetrics.utilities.imports import _JIWER_AVAILABLE from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester -if _JIWER_AVAILABLE: - from jiwer import compute_measures -else: - compute_measures: Callable - def _reference_jiwer_wer(preds: Union[str, List[str]], target: Union[str, List[str]]): + try: + from jiwer import compute_measures + except ImportError: + pytest.skip("test requires jiwer package to be installed") + return compute_measures(target, preds)["wer"] -@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer") @pytest.mark.parametrize( ["preds", "targets"], [ diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index 9b88866071a..08ecad16284 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -14,20 +14,22 @@ from typing import List, Union import pytest -from jiwer import wil from torchmetrics.functional.text.wil import word_information_lost from torchmetrics.text.wil import WordInfoLost -from torchmetrics.utilities.imports import _JIWER_AVAILABLE from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester def _reference_jiwer_wil(preds: Union[str, List[str]], target: Union[str, List[str]]): + try: + from jiwer import wil + except ImportError: + pytest.skip("test requires jiwer package to be installed") + return wil(target, preds) -@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer") 
@pytest.mark.parametrize( ["preds", "targets"], [ diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index c6ce8a89a8e..1900f7182b2 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -14,20 +14,22 @@ from typing import List, Union import pytest -from jiwer import wip from torchmetrics.functional.text.wip import word_information_preserved from torchmetrics.text.wip import WordInfoPreserved -from torchmetrics.utilities.imports import _JIWER_AVAILABLE from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester def _reference_jiwer_wip(preds: Union[str, List[str]], target: Union[str, List[str]]): + try: + from jiwer import wip + except ImportError: + pytest.skip("test requires jiwer package to be installed") + return wip(target, preds) -@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer") @pytest.mark.parametrize( ["preds", "targets"], [ From 4230cfef3d2020fffff873565acea01ad883d3e4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 28 Feb 2024 13:05:55 +0100 Subject: [PATCH 05/10] ci: adding testing with mac M1 (#2385) * adding tests for mac M1 * RelativeAverageSpectralError Expected: tensor(5114.6641) Got: tensor(5114.6636) * mecab * Apply suggestions from code review --- .github/workflows/ci-integrate.yml | 1 + .github/workflows/ci-tests.yml | 7 +++++++ Makefile | 3 ++- src/torchmetrics/functional/image/_deprecated.py | 2 +- src/torchmetrics/functional/image/rase.py | 2 +- src/torchmetrics/image/_deprecated.py | 2 +- src/torchmetrics/image/rase.py | 2 +- tests/unittests/audio/test_pesq.py | 8 ++++---- 8 files changed, 18 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci-integrate.yml b/.github/workflows/ci-integrate.yml index 245ecc08ac5..a632489c6a7 100644 --- a/.github/workflows/ci-integrate.yml +++ b/.github/workflows/ci-integrate.yml @@ -34,6 +34,7 @@ jobs: - { python-version: "3.10", os: "windows" } # todo: https://discuss.pytorch.org/t/numpy-is-not-available-error/146192 include: - { python-version: "3.10", requires: "latest", os: "ubuntu-22.04" } + - { python-version: "3.10", requires: "latest", os: "macOS-14" } # M1 machine env: PYTORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" FREEZE_REQUIREMENTS: ${{ ! 
(github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 755b78b2807..89e3a04a5bc 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -42,13 +42,19 @@ jobs: - "2.1.2" - "2.2.1" include: + # cover additional python nad PR combinations - { os: "ubuntu-22.04", python-version: "3.8", pytorch-version: "1.13.1" } - { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" } - { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.1" } - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.2.1" } + # standard mac machine, not the M1 - { os: "macOS-12", python-version: "3.8", pytorch-version: "1.13.1" } - { os: "macOS-12", python-version: "3.10", pytorch-version: "2.0.1" } - { os: "macOS-12", python-version: "3.11", pytorch-version: "2.2.1" } + # using the ARM based M1 machine + - { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1" } + - { os: "macOS-14", python-version: "3.11", pytorch-version: "2.2.1" } + # some windows - { os: "windows-2022", python-version: "3.8", pytorch-version: "1.13.1" } - { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1" } - { os: "windows-2022", python-version: "3.11", pytorch-version: "2.2.1" } @@ -75,6 +81,7 @@ jobs: if: ${{ runner.os == 'macOS' }} run: | echo 'UNITTEST_TIMEOUT=--timeout=75' >> $GITHUB_ENV + brew install mecab # https://github.com/coqui-ai/TTS/issues/1533#issuecomment-1338662303 brew install gcc libomp ffmpeg # https://github.com/pytorch/pytorch/issues/20030 - name: Setup Linux if: ${{ runner.os == 'Linux' }} diff --git a/Makefile b/Makefile index 9e4b7029b90..6f25ea84a6b 100644 --- a/Makefile +++ b/Makefile @@ -36,5 +36,6 @@ env: pip install -e . -U -r requirements/_devel.txt data: - python -c "from urllib.request import urlretrieve ; urlretrieve('https://pl-public-data.s3.amazonaws.com/metrics/data.zip', 'data.zip')" + pip install -q wget + python -m wget https://pl-public-data.s3.amazonaws.com/metrics/data.zip unzip -o data.zip -d ./tests diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index 0b46b25d32a..d0649ab501f 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -110,7 +110,7 @@ def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: >>> preds = torch.rand(4, 3, 16, 16, generator=gen) >>> target = torch.rand(4, 3, 16, 16, generator=gen) >>> _relative_average_spectral_error(preds, target) - tensor(5114.6641) + tensor(5114.66...) """ _deprecated_root_import_func("relative_average_spectral_error", "image") diff --git a/src/torchmetrics/functional/image/rase.py b/src/torchmetrics/functional/image/rase.py index bd561f4a479..54d20c6eee0 100644 --- a/src/torchmetrics/functional/image/rase.py +++ b/src/torchmetrics/functional/image/rase.py @@ -85,7 +85,7 @@ def relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: >>> preds = torch.rand(4, 3, 16, 16) >>> target = torch.rand(4, 3, 16, 16) >>> relative_average_spectral_error(preds, target) - tensor(5114.6641) + tensor(5114.66...) Raises: ValueError: If ``window_size`` is not a positive integer. 
diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index 7aa7a63743d..bad0457f1f0 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -109,7 +109,7 @@ class _RelativeAverageSpectralError(RelativeAverageSpectralError): >>> target = torch.rand(4, 3, 16, 16) >>> rase = _RelativeAverageSpectralError() >>> rase(preds, target) - tensor(5114.6641) + tensor(5114.66...) """ diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index 297fb6c80c7..c422762eb68 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -53,7 +53,7 @@ class RelativeAverageSpectralError(Metric): >>> target = torch.rand(4, 3, 16, 16) >>> rase = RelativeAverageSpectralError() >>> rase(preds, target) - tensor(5114.6641) + tensor(5114.66...) Raises: ValueError: If ``window_size`` is not a positive integer. diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index c10cfef6568..348f99c13d6 100644 --- a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -130,7 +130,7 @@ def test_on_real_audio(): """Test that metric works as expected on real audio signals.""" rate, ref = wavfile.read(_SAMPLE_AUDIO_SPEECH) rate, deg = wavfile.read(_SAMPLE_AUDIO_SPEECH_BAB_DB) - pesq = perceptual_evaluation_speech_quality(torch.from_numpy(deg), torch.from_numpy(ref), rate, "wb") - assert pesq == 1.0832337141036987 - pesq = perceptual_evaluation_speech_quality(torch.from_numpy(deg), torch.from_numpy(ref), rate, "nb") - assert pesq == 1.6072081327438354 + pesq_score = perceptual_evaluation_speech_quality(torch.from_numpy(deg), torch.from_numpy(ref), rate, "wb") + assert torch.allclose(pesq_score, torch.tensor(1.0832337141036987), atol=1e-4) + pesq_score = perceptual_evaluation_speech_quality(torch.from_numpy(deg), torch.from_numpy(ref), rate, "nb") + assert torch.allclose(pesq_score, torch.tensor(1.6072081327438354), atol=1e-4) From 043df7b873e09c13b28b3700c27aa6da3273bcd3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 2 Mar 2024 09:22:55 +0100 Subject: [PATCH 06/10] build(deps): bump pypa/gh-action-pypi-publish from 1.8.11 to 1.8.12 (#2419) Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.8.11 to 1.8.12. - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.8.11...v1.8.12) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/publish-pkg.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-pkg.yml b/.github/workflows/publish-pkg.yml index c556bfbe1c7..48c7c974e70 100644 --- a/.github/workflows/publish-pkg.yml +++ b/.github/workflows/publish-pkg.yml @@ -67,7 +67,7 @@ jobs: - run: ls -lh dist/ # We do this, since failures on test.pypi aren't that bad - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1.8.11 + uses: pypa/gh-action-pypi-publish@v1.8.12 with: user: __token__ password: ${{ secrets.test_pypi_password }} @@ -94,7 +94,7 @@ jobs: path: dist - run: ls -lh dist/ - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.8.11 + uses: pypa/gh-action-pypi-publish@v1.8.12 with: user: __token__ password: ${{ secrets.pypi_password }} From 14741e2717d9c5b6fbb3a375f579a08e1b6afe91 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Mar 2024 07:22:29 +0100 Subject: [PATCH 07/10] build(deps): update huggingface-hub requirement from <0.21 to <0.22 in /requirements (#2420) * build(deps): update huggingface-hub requirement in /requirements Updates the requirements on [huggingface-hub](https://github.com/huggingface/huggingface_hub) to permit the latest version. - [Release notes](https://github.com/huggingface/huggingface_hub/releases) - [Commits](https://github.com/huggingface/huggingface_hub/compare/v0.0.1...v0.21.3) --- updated-dependencies: - dependency-name: huggingface-hub dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Drop note --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- requirements/text_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/text_test.txt b/requirements/text_test.txt index 7e76e9993c4..d670edf5832 100644 --- a/requirements/text_test.txt +++ b/requirements/text_test.txt @@ -4,5 +4,5 @@ jiwer >=2.3.0, <3.1.0 rouge-score >0.1.0, <=0.1.2 bert_score ==0.3.13 -huggingface-hub <0.21 # hotfix, failing SDR for latest PT 1.11 +huggingface-hub <0.22 sacrebleu >=2.3.0, <2.5.0 From 2c2316e8089d79ed28dc11ccbf06b6fdad4d7e8c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Mar 2024 07:51:23 +0100 Subject: [PATCH 08/10] build(deps): bump pytest-doctestplus from 1.1.0 to 1.2.0 in /requirements (#2421) build(deps): bump pytest-doctestplus in /requirements Bumps [pytest-doctestplus](https://github.com/scientific-python/pytest-doctestplus) from 1.1.0 to 1.2.0. - [Release notes](https://github.com/scientific-python/pytest-doctestplus/releases) - [Changelog](https://github.com/scientific-python/pytest-doctestplus/blob/main/CHANGES.rst) - [Commits](https://github.com/scientific-python/pytest-doctestplus/compare/v1.1.0...v1.2.0) --- updated-dependencies: - dependency-name: pytest-doctestplus dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/_doctest.txt | 2 +- requirements/_tests.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/_doctest.txt b/requirements/_doctest.txt index e2247199043..ee8fd6c2e1e 100644 --- a/requirements/_doctest.txt +++ b/requirements/_doctest.txt @@ -2,5 +2,5 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment pytest >=6.0.0, <7.5.0 -pytest-doctestplus >=0.9.0, <=1.1.0 +pytest-doctestplus >=0.9.0, <=1.2.0 pytest-rerunfailures >=10.0, <14.0 diff --git a/requirements/_tests.txt b/requirements/_tests.txt index a30c04a3135..0f39bf8d646 100644 --- a/requirements/_tests.txt +++ b/requirements/_tests.txt @@ -4,7 +4,7 @@ coverage ==7.4.3 pytest ==7.4.4 pytest-cov ==4.1.0 -pytest-doctestplus ==1.1.0 +pytest-doctestplus ==1.2.0 pytest-rerunfailures ==13.0 pytest-timeout ==2.2.0 pytest-xdist ==3.5.0 From 1951a06fc914a26152e635f62aa8b32399d4c700 Mon Sep 17 00:00:00 2001 From: Su YR Date: Tue, 5 Mar 2024 17:17:21 +0800 Subject: [PATCH 09/10] fix: `MetricCollection` did not copy the inner state of the metric in `ClasswiseWrapper` when computing group metrics (#2390) * fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics Issue Link: https://github.com/Lightning-AI/torchmetrics/issues/2389 * fix: set _persistent and _reductions be same as internal metric * test: check metric state_dict wrapped in `ClasswiseWrapper` --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/wrappers/classwise.py | 23 +++++++++++++ tests/unittests/bases/test_collections.py | 40 ++++++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 3c8d6621bc2..698d0f51848 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import typing from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor @@ -20,6 +21,9 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE from torchmetrics.wrappers.abstract import WrapperMetric +if typing.TYPE_CHECKING: + from torch.nn import Module + if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["ClasswiseWrapper.plot"] @@ -209,3 +213,22 @@ def plot( """ return self._plot(val, ax) + + def __getattr__(self, name: str) -> Union[Tensor, "Module"]: + """Get attribute from classwise wrapper.""" + # return state from self.metric + if name in ["tp", "fp", "fn", "tn"]: + return getattr(self.metric, name) + + return super().__getattr__(name) + + def __setattr__(self, name: str, value: Any) -> None: + """Set attribute to classwise wrapper.""" + super().__setattr__(name, value) + if name == "metric": + self._defaults = self.metric._defaults + self._persistent = self.metric._persistent + self._reductions = self.metric._reductions + if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]: + # update ``_update_count`` and ``_computed`` of internal metric to prevent warning. 
+ setattr(self.metric, name, value) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 9e4ac4a5897..16c95fc879a 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -17,7 +17,7 @@ import pytest import torch -from torchmetrics import Metric, MetricCollection +from torchmetrics import ClasswiseWrapper, Metric, MetricCollection from torchmetrics.classification import ( BinaryAccuracy, MulticlassAccuracy, @@ -540,6 +540,44 @@ def test_compute_group_define_by_user(): assert m.compute() +def test_classwise_wrapper_compute_group(): + """Check that user can provide compute groups.""" + classwise_accuracy = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy") + classwise_recall = ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall") + classwise_precision = ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision") + + m = MetricCollection( + { + "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy"), + "recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall"), + "precision": ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision"), + }, + compute_groups=[["accuracy", "recall", "precision"]], + ) + + # Check that we are not going to check the groups in the first update + assert m._groups_checked + assert m.compute_groups == {0: ["accuracy", "recall", "precision"]} + + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + + expected = { + **classwise_accuracy(preds, target), + **classwise_recall(preds, target), + **classwise_precision(preds, target), + } + + m.update(preds, target) + res = m.compute() + + for key in expected: + assert torch.allclose(res[key], expected[key]) + + # check metric state_dict + m.state_dict() + + def test_compute_on_different_dtype(): """Check that extraction of compute groups are robust towards difference in dtype.""" m = MetricCollection([ From 9d76f3f9f1b0f067869e67f0bd426035a332b28c Mon Sep 17 00:00:00 2001 From: Su YR Date: Tue, 5 Mar 2024 22:09:23 +0800 Subject: [PATCH 10/10] Feat: make `__getattr__` and `__setattr__` of ClasswiseWrapper more general (#2424) * fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics Issue Link: https://github.com/Lightning-AI/torchmetrics/issues/2389 * fix: set _persistent and _reductions be same as internal metric * test: check metric state_dict wrapped in `ClasswiseWrapper` * refactor: make __getattr__ and __setattr__ of ClasswiseWrapper more general * chlog --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka --- CHANGELOG.md | 2 +- src/torchmetrics/wrappers/classwise.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f2aa69ba29..f970dbc9df2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424)) ### Deprecated diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 698d0f51848..0920118c919 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ 
b/src/torchmetrics/wrappers/classwise.py @@ -216,19 +216,19 @@ def plot( def __getattr__(self, name: str) -> Union[Tensor, "Module"]: """Get attribute from classwise wrapper.""" - # return state from self.metric - if name in ["tp", "fp", "fn", "tn"]: - return getattr(self.metric, name) + if name == "metric" or (name in self.__dict__ and name not in self.metric.__dict__): + # we need this to prevent from infinite getattribute loop. + return super().__getattr__(name) - return super().__getattr__(name) + return getattr(self.metric, name) def __setattr__(self, name: str, value: Any) -> None: """Set attribute to classwise wrapper.""" - super().__setattr__(name, value) - if name == "metric": - self._defaults = self.metric._defaults - self._persistent = self.metric._persistent - self._reductions = self.metric._reductions - if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]: - # update ``_update_count`` and ``_computed`` of internal metric to prevent warning. + if hasattr(self, "metric") and name in self.metric._defaults: setattr(self.metric, name, value) + else: + super().__setattr__(name, value) + if name == "metric": + self._defaults = self.metric._defaults + self._persistent = self.metric._persistent + self._reductions = self.metric._reductions
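
A minimal usage sketch tying together the ClasswiseWrapper changes in patches 09 and 10: classwise-wrapped metrics placed in a MetricCollection with an explicit compute group, where the wrapper now forwards state (e.g. `tp`, `fp`) to its inner metric so that group syncing and `state_dict()` round-trip cleanly. It mirrors the regression test added in patch 09 rather than introducing new behaviour; it assumes a torchmetrics build that already carries PRs #2390 and #2424, and the exact per-class key names returned by `compute()` are version-dependent.

import torch

from torchmetrics import ClasswiseWrapper, MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall

collection = MetricCollection(
    {
        "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy"),
        "recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall"),
    },
    # both metrics share the same underlying update, so they can be computed as one group
    compute_groups=[["accuracy", "recall"]],
)

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))

collection.update(preds, target)
result = collection.compute()  # flat dict with per-class entries derived from the given prefixes
print(sorted(result))

# With the attribute forwarding in place, persisting and restoring the collection works
# even though the wrapped metrics share state through the compute group.
state = collection.state_dict()
collection.load_state_dict(state)

Without these two patches, the state copied between compute-group members was not propagated into the wrapped inner metric, which is the failure mode reported in issue #2389.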