diff --git a/CHANGELOG.md b/CHANGELOG.md index 27e5ac4c9f9..3dee4d9252f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed the use of `max_fpr` in `AUROC` metric when only one class is present ([#1895](https://github.com/Lightning-AI/torchmetrics/pull/1895)) + + - Fixed bug related to empty predictions for `IntersectionOverUnion` metric ([#1892](https://github.com/Lightning-AI/torchmetrics/pull/1892)) diff --git a/requirements/doctest.txt b/requirements/doctest.txt index ca514d5f68a..117e0d75d0e 100644 --- a/requirements/doctest.txt +++ b/requirements/doctest.txt @@ -3,4 +3,4 @@ pytest >=6.0.0, <7.5.0 pytest-doctestplus >=0.9.0, <=0.13.0 -pytest-rerunfailures >=10.0, <12.0 +pytest-rerunfailures >=10.0, <13.0 diff --git a/requirements/test.txt b/requirements/test.txt index 76745015187..2c25761a0af 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -5,7 +5,7 @@ coverage ==7.2.7 pytest ==7.4.0 pytest-cov ==4.1.0 pytest-doctestplus ==0.13.0 -pytest-rerunfailures ==11.1.2 +pytest-rerunfailures ==12.0 pytest-timeout ==2.1.0 phmdoctest ==1.4.0 diff --git a/requirements/visual.txt b/requirements/visual.txt index d854f70ee4a..0b61f800282 100644 --- a/requirements/visual.txt +++ b/requirements/visual.txt @@ -1,5 +1,5 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -matplotlib >=3.2.0, <=3.7.1 +matplotlib >=3.2.0, <=3.7.2 SciencePlots >= 2.0.0, <= 2.1.0 diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 94ca84877f0..8d0157fe7aa 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -86,7 +86,7 @@ def _binary_auroc_compute( pos_label: int = 1, ) -> Tensor: fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label) - if max_fpr is None or max_fpr == 1: + if max_fpr is None or max_fpr == 1 or fpr.sum() == 0 or tpr.sum() == 0: return _auc_compute_without_check(fpr, tpr, 1.0) _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 0fdf7834da0..e51f564b02f 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -393,3 +393,17 @@ def test_valid_input_thresholds(metric, thresholds): with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 + + +@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) +def test_corner_case_max_fpr(max_fpr): + """Check that metric returns 0 when one class is missing and `max_fpr` is set.""" + preds = torch.tensor([0.1, 0.2, 0.3, 0.4]) + target = torch.tensor([0, 0, 0, 0]) + metric = BinaryAUROC(max_fpr=max_fpr) + assert metric(preds, target) == 0.0 + + preds = torch.tensor([0.5, 0.6, 0.7, 0.8]) + target = torch.tensor([1, 1, 1, 1]) + metric = BinaryAUROC(max_fpr=max_fpr) + assert metric(preds, target) == 0.0