From 3f0ce01307416d310ebabbcdd2d9dacc65d852bc Mon Sep 17 00:00:00 2001 From: vjoki Date: Mon, 1 Feb 2021 01:38:43 +0200 Subject: [PATCH] Check for inaccurate EER values --- snn/librispeech/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/snn/librispeech/utils.py b/snn/librispeech/utils.py index 84adf2a..6659c18 100644 --- a/snn/librispeech/utils.py +++ b/snn/librispeech/utils.py @@ -1,4 +1,5 @@ -from typing import cast, Tuple, List +from typing import cast, Tuple, List, Dict, Optional +import warnings import torch import matplotlib.pyplot as plt from pytorch_lightning.metrics.functional import roc @@ -19,7 +20,8 @@ def minDCF(fpr: torch.Tensor, fnr: torch.Tensor, thresholds: torch.Tensor, def equal_error_rate(fpr: torch.Tensor, fnr: torch.Tensor, - thresholds: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + thresholds: torch.Tensor, + warn_threshold: float = 0.1) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Index of the nearest intersection point. idx = torch.argmin(torch.abs(fnr - fpr)) @@ -31,6 +33,10 @@ def equal_error_rate(fpr: torch.Tensor, fnr: torch.Tensor, eer = 0.5 * (fpr[idx] + fnr[idx]) eer_threshold = thresholds[idx] + # Since we're averaging the EER, warn if the difference between FNR and FPR is too large. + if torch.abs(fpr[idx] - fnr[idx]) > warn_threshold: + warnings.warn('Inaccurate EER ({}), real EER is somewhere between or near [{}, {}].'.format(eer, fpr[idx], fnr[idx])) + # https://yangcha.github.io/EER-ROC/ # https://stackoverflow.com/questions/28339746/equal-error-rate-in-python # Unfortunately this seems to fail frequently, incorrectly returning 0.0 and nan.