Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/image_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored May 29, 2023
2 parents c35d0f0 + f3c6c64 commit de02c9d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed several bugs in `SpectralDistortionIndex` metric ([#1808](https://github.com/Lightning-AI/torchmetrics/pull/1808))


- Fixed bug for corner cases in `MatthewsCorrCoef` ([#1812](https://github.com/Lightning-AI/torchmetrics/pull/1812))


## [0.11.4] - 2023-03-10

### Fixed
Expand Down
29 changes: 26 additions & 3 deletions src/torchmetrics/functional/classification/matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,21 @@


def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor:
"""Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the matthews corrcoef score."""
"""Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the matthews corrcoef score.
See: https://bmcgenomics.biomedcentral.com/articles/10.1186/s12864-019-6413-7 for more info.
"""
# convert multilabel into binary
confmat = confmat.sum(0) if confmat.ndim == 3 else confmat

if confmat.numel() == 4: # binary case
tn, fp, fn, tp = confmat.reshape(-1)
if tp != 0 and tn == 0 and fp == 0 and fn == 0:
return torch.tensor(1.0, dtype=confmat.dtype, device=confmat.device)

if tp == 0 and tn != 0 and fp == 0 and fn == 0:
return torch.tensor(-1.0, dtype=confmat.dtype, device=confmat.device)

tk = confmat.sum(dim=-1).float()
pk = confmat.sum(dim=-2).float()
c = torch.trace(confmat).float()
Expand All @@ -48,10 +59,22 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor:
cov_ypyp = s**2 - sum(pk * pk)
cov_ytyt = s**2 - sum(tk * tk)

numerator = cov_ytyp
denom = cov_ypyp * cov_ytyt
if denom == 0:

if denom == 0 and confmat.numel() == 4:
if tp == 0 or tn == 0:
a = tp + tn

if fp == 0 or fn == 0:
b = fp + fn

eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device)
numerator = torch.sqrt(eps) * (a - b)
denom = torch.sqrt(2 * (a + b) * (a + eps) * (b + eps))
elif denom == 0:
return torch.tensor(0, dtype=confmat.dtype, device=confmat.device)
return cov_ytyp / torch.sqrt(denom)
return numerator / torch.sqrt(denom)


def binary_matthews_corrcoef(
Expand Down
39 changes: 35 additions & 4 deletions tests/unittests/classification/test_matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _sklearn_matthews_corrcoef_binary(preds, target, ignore_index=None):
class TestBinaryMatthewsCorrCoef(MetricTester):
"""Test class for `BinaryMatthewsCorrCoef` metric."""

@pytest.mark.parametrize("ignore_index", [None, -1, 0])
@pytest.mark.parametrize("ignore_index", [None, -1])
@pytest.mark.parametrize("ddp", [True, False])
def test_binary_matthews_corrcoef(self, inputs, ddp, ignore_index):
"""Test class implementation of metric."""
Expand All @@ -71,7 +71,7 @@ def test_binary_matthews_corrcoef(self, inputs, ddp, ignore_index):
},
)

@pytest.mark.parametrize("ignore_index", [None, -1, 0])
@pytest.mark.parametrize("ignore_index", [None, -1])
def test_binary_matthews_corrcoef_functional(self, inputs, ignore_index):
"""Test functional implementation of metric."""
preds, target = inputs
Expand Down Expand Up @@ -234,7 +234,7 @@ def _sklearn_matthews_corrcoef_multilabel(preds, target, ignore_index=None):
class TestMultilabelMatthewsCorrCoef(MetricTester):
"""Test class for `MultilabelMatthewsCorrCoef` metric."""

@pytest.mark.parametrize("ignore_index", [None, -1, 0])
@pytest.mark.parametrize("ignore_index", [None, -1])
@pytest.mark.parametrize("ddp", [True, False])
def test_multilabel_matthews_corrcoef(self, inputs, ddp, ignore_index):
"""Test class implementation of metric."""
Expand All @@ -253,7 +253,7 @@ def test_multilabel_matthews_corrcoef(self, inputs, ddp, ignore_index):
},
)

@pytest.mark.parametrize("ignore_index", [None, -1, 0])
@pytest.mark.parametrize("ignore_index", [None, -1])
def test_multilabel_matthews_corrcoef_functional(self, inputs, ignore_index):
"""Test functional implementation of metric."""
preds, target = inputs
Expand Down Expand Up @@ -316,3 +316,34 @@ def test_zero_case_in_multiclass():
# Example where neither 1 or 2 is present in the target tensor
out = multiclass_matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3)
assert out == 0.0


@pytest.mark.parametrize(
("metric_fn", "preds", "target", "expected"),
[
(binary_matthews_corrcoef, torch.zeros(10), torch.zeros(10), -1.0),
(binary_matthews_corrcoef, torch.ones(10), torch.ones(10), 1.0),
(
binary_matthews_corrcoef,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]),
0.0,
),
(
partial(multilabel_matthews_corrcoef, num_labels=NUM_CLASSES),
torch.zeros(10, NUM_CLASSES).long(),
torch.zeros(10, NUM_CLASSES).long(),
-1.0,
),
(
partial(multilabel_matthews_corrcoef, num_labels=NUM_CLASSES),
torch.ones(10, NUM_CLASSES).long(),
torch.ones(10, NUM_CLASSES).long(),
1.0,
),
],
)
def test_corner_cases(metric_fn, preds, target, expected):
"""Test the corner cases of perfect classifiers or completely random classifiers that they work as expected."""
out = metric_fn(preds, target)
assert out == expected

0 comments on commit de02c9d

Please sign in to comment.