BinaryAccuracy() sometimes gives incorrect answers due to non-deterministic sigmoiding  #1604

Open
@idc9

🐛 Bug

torchmetrics.classification.BinaryAccuracy applies a sigmoid to some inputs but not others, which leads to incorrect results.

Details

The current behavior of BinaryAccuracy() is to apply a sigmoid transformation before binarizing if the inputs fall outside [0, 1]:

If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

i.e.
y_hat = 1(sigmoid(z) >= threshold) if z outside [0, 1]
y_hat = 1(z >= threshold) if z inside [0, 1]

I assume the check for z inside [0, 1] is made over the entire batch (i.e. if any element of the batch is outside [0, 1], the sigmoid is applied to every element).

This will cause silent errors. In particular, a user who passes logits expects them to always be sigmoided. However, it is entirely possible for all of the logits in a batch to lie in [0, 1], in which case the input is not sigmoided and the thresholding is incorrect.
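
To make the failure mode concrete, here is a minimal pure-Python sketch of the batch-level rule as I understand it from the docs quoted above. `sigmoid` and `binarize_as_described` are hypothetical helpers written for this illustration. They are not torchmetrics code, and the batch-level check is my assumption.

```python
import math


def sigmoid(z):
    """Standard logistic function."""
    return 1.0 / (1.0 + math.exp(-z))


def binarize_as_described(preds, threshold=0.5):
    """Hypothetical re-statement of the rule described in this issue.

    Assumption: the [0, 1] check is made over the whole batch, and the
    sigmoid is applied to every element only if at least one value
    falls outside [0, 1].
    """
    all_in_unit_interval = all(0.0 <= p <= 1.0 for p in preds)
    if not all_in_unit_interval:
        preds = [sigmoid(p) for p in preds]
    return [int(p >= threshold) for p in preds]


# Two batches of logits; both contain the same logit 0.49.
batch_a = [0.49, 2.0]  # 2.0 is outside [0, 1], so the batch is sigmoided
batch_b = [0.49, 0.9]  # everything is inside [0, 1], so no sigmoid is applied

labels_a = binarize_as_described(batch_a)
labels_b = binarize_as_described(batch_b)

print(labels_a)  # [1, 1]
print(labels_b)  # [0, 1]
```

Under this rule the same logit, 0.49, is labelled class 1 in the first batch and class 0 in the second, depending only on what else happens to be in the batch.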

To Reproduce

Here is a simple example. Suppose our network outputs logits.

from torchmetrics.classification import BinaryAccuracy
from scipy.special import expit # expit = sigmoid
import numpy as np
import torch

This example should lead to a correct prediction

probability_thresh = 0.5 
logits = np.array([0.49]) # network output
target = np.array([1])

# logits of 0.49 give a probability of 0.62 indicating class 1, the correct prediction
expit(logits)
array([0.62010643])
int(expit(logits) >= probability_thresh) == target
True

BinaryAccuracy(), however, thinks it's an incorrect prediction.

# torchmetrics, however, thinks we have the incorrect prediction because it does NOT sigmoid the logits
ba = BinaryAccuracy(threshold=probability_thresh) 
ba.forward(preds=torch.tensor(logits), target=torch.tensor(target))
tensor(0.)

Suggested Fix

I suggest adding an argument indicating whether or not the input predictions are already sigmoided, so that the inputs are either always sigmoided or never sigmoided.
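
As a rough sketch of what I mean (pure Python, independent of torchmetrics), the function below takes an explicit flag instead of inspecting the values. The names `binary_accuracy_explicit` and `preds_are_logits` are hypothetical and only illustrate the idea. They are not an existing torchmetrics API.

```python
import math


def sigmoid(z):
    """Standard logistic function."""
    return 1.0 / (1.0 + math.exp(-z))


def binary_accuracy_explicit(preds, target, threshold=0.5, preds_are_logits=False):
    """Hypothetical accuracy function with an explicit input-type flag.

    The sigmoid is applied if and only if the caller says the inputs
    are logits; the values themselves are never inspected.
    """
    if preds_are_logits:
        preds = [sigmoid(p) for p in preds]
    labels = [int(p >= threshold) for p in preds]
    correct = sum(int(label == t) for label, t in zip(labels, target))
    return correct / len(target)


# Same data as the reproduction above: a single logit of 0.49 with target 1.
acc_logits = binary_accuracy_explicit([0.49], [1], preds_are_logits=True)
acc_probs = binary_accuracy_explicit([0.49], [1], preds_are_logits=False)

print(acc_logits)  # 1.0
print(acc_probs)   # 0.0
```

With the flag, the result depends only on what the caller declares, so a batch of logits that happens to lie in [0, 1] is handled the same way as any other batch.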
