## 🐛 Bug
The method `torchmetrics.functional.classification.binary_auroc` has at least two bugs when run on an MPS device. Bug 1 seems more serious than bug 2, and I suspect the bugs are related.

To demonstrate the two bugs, we will use the following code:
```python
import torch
from torchmetrics.functional.classification import binary_auroc

torch.manual_seed(42)

device_cpu = torch.device("cpu")
device_mps = torch.device("mps")

def test_auroc(n, thresholds_cpu, thresholds_mps):
    preds_cpu = torch.rand(n, device=device_cpu)
    target_cpu = torch.rand(n, device=device_cpu).round().int()
    preds_mps = preds_cpu.clone().to(device_mps)
    target_mps = target_cpu.clone().to(device_mps)
    auroc_cpu = binary_auroc(preds_cpu, target_cpu, thresholds=thresholds_cpu)
    auroc_mps = binary_auroc(preds_mps, target_mps, thresholds=thresholds_mps)
    print("CPU AUROC: ", auroc_cpu.item())
    print("MPS AUROC: ", auroc_mps.item())
```
### Bug 1: using no thresholds and enough data always gives 0 AUROC score on MPS device
Calling `test_auroc(2**16, None, None)` prints

```
CPU AUROC: 0.4983136057853699
MPS AUROC: 0.4983136057853699
```
which seems reasonable. However, calling `test_auroc(2**17, None, None)` prints

```
CPU AUROC: 0.49919775128364563
MPS AUROC: 0.0
```
which seems wrong. I would have expected the MPS AUROC score to be identical to the CPU AUROC score, and with even greater certainty I would have expected both AUROC scores to be much closer to 0.5 than to 0.
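Until this is fixed, a minimal workaround sketch is to route the computation through the CPU. The helper below is my own, not part of torchmetrics, and it assumes the CPU path is the trusted reference and that the device round trip is acceptable:

```python
import torch
from torchmetrics.functional.classification import binary_auroc

def binary_auroc_via_cpu(preds: torch.Tensor, target: torch.Tensor, thresholds=None) -> torch.Tensor:
    # Hypothetical helper: evaluate on the CPU to avoid the buggy MPS path,
    # then move the scalar result back to the original device.
    result = binary_auroc(preds.cpu(), target.cpu(), thresholds=thresholds)
    return result.to(preds.device)
```

With this helper, `binary_auroc_via_cpu(preds_mps, target_mps)` matches the CPU result by construction, at the cost of two device transfers.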
### Bug 2: using thresholds and sufficiently small data gives inconsistent AUROC score on MPS device
Calling `test_auroc(2**17, 128, 128)` prints

```
CPU AUROC: 0.4991825520992279
MPS AUROC: 0.4991825520992279
```
which seems reasonable. However, calling `test_auroc(2**16, 128, 128)` prints

```
CPU AUROC: 0.4983268976211548
MPS AUROC: 0.4983268678188324
```
which seems wrong. I would have expected the MPS AUROC score to be identical to the CPU AUROC score, but it deviates slightly. I am not sure whether to expect such a deviation at all, but I am surprised that it appears for small inputs and not for large ones. The deviations could be canceling out for sufficiently large input, but since this happens at exactly the same input size as the one triggering bug 1, I suspect bug 2 is related to bug 1.
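To make the drift concrete, here is the kind of check I would use to compare the two results (the tolerance is an arbitrary choice of mine, not anything torchmetrics prescribes):

```python
import torch

def report_deviation(auroc_cpu: torch.Tensor, auroc_mps: torch.Tensor, atol: float = 1e-6) -> None:
    # Move the MPS result to the CPU first, since torch.isclose needs both
    # tensors on the same device, then compare against the CPU reference.
    auroc_mps = auroc_mps.cpu()
    diff = (auroc_cpu - auroc_mps).abs().item()
    close = torch.isclose(auroc_cpu, auroc_mps, atol=atol).item()
    print(f"abs diff: {diff:.3e}, within atol={atol}: {close}")
```

For the `2**16` case above, the absolute difference is about 3e-8, which on its own could pass for floating-point accumulation noise; what makes me suspicious is that the behavior changes at the same size boundary where bug 1 kicks in.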
### Bug 3: wrong return type hint in `binary_auroc`

The return type hint for `torchmetrics.functional.classification.binary_auroc` says it returns a `Tuple[Tensor, Tensor, Tensor]`, but as seen in the code above it really returns just a `Tensor`. I believe the type hint should be corrected to `Union[Tensor, Tuple[Tensor, Tensor, Tensor]]`.
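For concreteness, a sketch of the annotation I have in mind is below. The parameter list is abbreviated; the real function takes more arguments (e.g. `max_fpr`, `ignore_index`) than shown here:

```python
from typing import List, Optional, Tuple, Union

from torch import Tensor

def binary_auroc(
    preds: Tensor,
    target: Tensor,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
    ...
```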
## Environment
- TorchMetrics version (from `pip`): 0.11.4
- Python version: 3.11.2
- PyTorch version: 2.0.0
- OS: macOS