Description
🐛 Bug
torchmetrics.classification.BinaryAccuracy
will apply a sigmoid to some inputs but not others, leading to incorrect behavior.
Details
The current behavior of BinaryAccuracy() is to apply a sigmoid transformation before binarizing if the inputs are outside of [0, 1].
If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and a sigmoid is auto-applied per element.
i.e.
y_hat = 1(sigmoid(z) >= threshold)   if z outside [0, 1]
y_hat = 1(z >= threshold)            if z inside [0, 1]
I assume z inside [0, 1] is checked for the entire batch (i.e., if any element of the batch is outside [0, 1], then the sigmoid is applied to every element).
This will cause silent errors. In particular, a user who inputs logits expects the logits to always be sigmoided. However, it is entirely possible for all of the logits in some batches to lie in [0, 1], in which case the input will not be sigmoided, which causes incorrect thresholding.
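Assuming the [0, 1] check is batch-wide, the dispatch can be sketched as below. Note that `sigmoid` and `binarize` here are hypothetical stand-ins for illustration, not the actual torchmetrics implementation; the point is that the same logit 0.49 is binarized differently depending on what else happens to be in the batch:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def binarize(preds, threshold=0.5):
    # Suspected dispatch (assumption: the [0, 1] check is batch-wide)
    if preds.min() < 0 or preds.max() > 1:
        preds = sigmoid(preds)
    return (preds >= threshold).astype(int)

batch_a = np.array([0.49, 0.30])  # all values in [0, 1]: treated as probabilities
batch_b = np.array([0.49, 1.30])  # one value outside: the whole batch is sigmoided

binarize(batch_a)  # the 0.49 logit -> class 0 (0.49 < 0.5)
binarize(batch_b)  # the same 0.49 logit -> class 1 (sigmoid(0.49) ~= 0.62 >= 0.5)
```

The same logit yields a different prediction purely because of its batchmates, which is exactly the silent error described above.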
To Reproduce
Here is a simple example. Suppose our network outputs logits.

from torchmetrics.classification import BinaryAccuracy
from scipy.special import expit  # expit = sigmoid
import numpy as np
import torch

This example should lead to a correct prediction:

probability_thresh = 0.5
logits = np.array([0.49])  # network output
target = np.array([1])

# logits of 0.49 give a probability of ~0.62, indicating class 1, the correct prediction
expit(logits)  # array([0.62010643])
int(expit(logits) >= probability_thresh) == target  # True

BinaryAccuracy(), however, thinks it is an incorrect prediction:

# torchmetrics does NOT sigmoid the logits here, since 0.49 lies inside [0, 1]
ba = BinaryAccuracy(threshold=probability_thresh)
ba.forward(preds=torch.tensor(logits), target=torch.tensor(target))  # tensor(0.)
Suggested Fix
I suggest adding an argument indicating whether or not the input predictions are logits, so that the inputs are either always sigmoided or never sigmoided.
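A minimal functional sketch of what such an argument could look like; the name `binary_accuracy` and the flag `preds_are_logits` are hypothetical, not the torchmetrics API:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def binary_accuracy(preds, target, threshold=0.5, preds_are_logits=False):
    # Hypothetical signature: the caller states up front whether preds are
    # logits, so the sigmoid is either always applied or never applied --
    # there is no data-dependent branch.
    if preds_are_logits:
        preds = sigmoid(preds)
    y_hat = (preds >= threshold).astype(int)
    return float((y_hat == target).mean())

binary_accuracy(np.array([0.49]), np.array([1]), preds_are_logits=True)   # 1.0
binary_accuracy(np.array([0.49]), np.array([1]), preds_are_logits=False)  # 0.0
```

With an explicit flag, the example above gives the same answer regardless of which batch the logit 0.49 happens to land in.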