Skip to content

Commit

Permalink
Merge branch 'master' into sa-sdr
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Jul 11, 2023
2 parents 60785f9 + 703d5b8 commit 5eecebc
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed the use of `max_fpr` in `AUROC` metric when only one class is present ([#1895](https://github.com/Lightning-AI/torchmetrics/pull/1895))


- Fixed bug related to empty predictions for `IntersectionOverUnion` metric ([#1892](https://github.com/Lightning-AI/torchmetrics/pull/1892))


Expand Down
2 changes: 1 addition & 1 deletion requirements/doctest.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

pytest >=6.0.0, <7.5.0
pytest-doctestplus >=0.9.0, <=0.13.0
pytest-rerunfailures >=10.0, <12.0
pytest-rerunfailures >=10.0, <13.0
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ coverage ==7.2.7
pytest ==7.4.0
pytest-cov ==4.1.0
pytest-doctestplus ==0.13.0
pytest-rerunfailures ==11.1.2
pytest-rerunfailures ==12.0
pytest-timeout ==2.1.0
phmdoctest ==1.4.0

Expand Down
2 changes: 1 addition & 1 deletion requirements/visual.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

matplotlib >=3.2.0, <=3.7.1
matplotlib >=3.2.0, <=3.7.2
SciencePlots >= 2.0.0, <= 2.1.0
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _binary_auroc_compute(
pos_label: int = 1,
) -> Tensor:
fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label)
if max_fpr is None or max_fpr == 1:
if max_fpr is None or max_fpr == 1 or fpr.sum() == 0 or tpr.sum() == 0:
return _auc_compute_without_check(fpr, tpr, 1.0)

_device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device
Expand Down
14 changes: 14 additions & 0 deletions tests/unittests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,17 @@ def test_valid_input_thresholds(metric, thresholds):
with pytest.warns(None) as record:
metric(thresholds=thresholds)
assert len(record) == 0


@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5])
def test_corner_case_max_fpr(max_fpr):
"""Check that metric returns 0 when one class is missing and `max_fpr` is set."""
preds = torch.tensor([0.1, 0.2, 0.3, 0.4])
target = torch.tensor([0, 0, 0, 0])
metric = BinaryAUROC(max_fpr=max_fpr)
assert metric(preds, target) == 0.0

preds = torch.tensor([0.5, 0.6, 0.7, 0.8])
target = torch.tensor([1, 1, 1, 1])
metric = BinaryAUROC(max_fpr=max_fpr)
assert metric(preds, target) == 0.0

0 comments on commit 5eecebc

Please sign in to comment.