11
11
import torch .nn .functional as F
12
12
import torchvision .transforms as transforms
13
13
14
+ from timm .loss import SoftTargetCrossEntropy , LabelSmoothingCrossEntropy
15
+
14
16
import ops .meters as meters
15
17
16
18
17
19
@torch .no_grad ()
18
20
def test (model , n_ff , dataset ,
19
- cutoffs = (0.0 , 0.9 ), bins = np .linspace (0.0 , 1.0 , 11 ), verbose = False , period = 10 , gpu = True ):
21
+ transform = None , smoothing = 0.0 ,
22
+ cutoffs = (0.0 , 0.9 ), bins = np .linspace (0.0 , 1.0 , 11 ),
23
+ verbose = False , period = 10 , gpu = True ):
20
24
model .eval ()
21
25
model = model .cuda () if gpu else model .cpu ()
22
26
xs , ys = next (iter (dataset ))
@@ -40,15 +44,30 @@ def test(model, n_ff, dataset,
40
44
xs = xs .cuda ()
41
45
ys = ys .cuda ()
42
46
47
+ if transform is not None :
48
+ xs , ys_t = transform (xs , ys )
49
+ else :
50
+ xs , ys_t = xs , ys
51
+
52
+ if len (ys_t .shape ) > 1 :
53
+ loss_function = SoftTargetCrossEntropy ()
54
+ ys = torch .max (ys_t , dim = - 1 )[1 ]
55
+ elif smoothing > 0.0 :
56
+ loss_function = LabelSmoothingCrossEntropy (smoothing = smoothing )
57
+ else :
58
+ loss_function = nn .CrossEntropyLoss ()
59
+ loss_function = loss_function .cuda () if gpu else loss_function
60
+
43
61
# A. Predict results
44
62
ys_pred = torch .stack ([F .softmax (model (xs ), dim = 1 ) for _ in range (n_ff )])
45
63
ys_pred = torch .mean (ys_pred , dim = 0 )
46
64
65
+ ys_t = ys_t .cpu ()
47
66
ys = ys .cpu ()
48
67
ys_pred = ys_pred .cpu ()
49
68
50
69
# B. Measure Confusion Matrices
51
- nll_meter .update (F . nll_loss (torch .log (ys_pred ), ys , reduction = "none" ).numpy ())
70
+ nll_meter .update (loss_function (torch .log (ys_pred ), ys_t ).numpy ())
52
71
topk_meter .update (topk (ys .numpy (), ys_pred .numpy ()))
53
72
brier_meter .update (brier (ys .numpy (), ys_pred .numpy ()))
54
73
@@ -74,12 +93,13 @@ def test(model, n_ff, dataset,
74
93
acc_bin = [gacc (cm_bin ) for cm_bin in cms_bin ]
75
94
conf_bin = [conf_acc / (count + 1e-7 ) for count , conf_acc in zip (count_bin , conf_acc_bin )]
76
95
ece_value = ece (count_bin , acc_bin , conf_bin )
96
+ ecse_value = ecse (count_bin , acc_bin , conf_bin )
77
97
78
98
metrics = nll_value , \
79
99
cutoffs , cms , accs , uncs , ious , freqs , \
80
100
topk_value , brier_value , \
81
- count_bin , acc_bin , conf_bin , ece_value
82
- if verbose and int (step + 1 ) % period is 0 :
101
+ count_bin , acc_bin , conf_bin , ece_value , ecse_value
102
+ if verbose and int (step + 1 ) % period == 0 :
83
103
print ("%d Steps, %s" % (int (step + 1 ), repr_metrics (metrics )))
84
104
85
105
print (repr_metrics (metrics ))
@@ -99,7 +119,7 @@ def repr_metrics(metrics):
99
119
nll_value , \
100
120
cutoffs , cms , accs , uncs , ious , freqs , \
101
121
topk_value , brier_value , \
102
- count_bin , acc_bin , conf_bin , ece_value = metrics
122
+ count_bin , acc_bin , conf_bin , ece_value , ecse_value = metrics
103
123
104
124
metrics_reprs = [
105
125
"NLL: %.4f" % nll_value ,
@@ -111,6 +131,7 @@ def repr_metrics(metrics):
111
131
"Top-5: " + "%.3f %%" % (topk_value * 100 ),
112
132
"Brier: " + "%.3f" % brier_value ,
113
133
"ECE: " + "%.3f %%" % (ece_value * 100 ),
134
+ "ECE±: " + "%.3f %%" % (ecse_value * 100 ),
114
135
]
115
136
116
137
return ", " .join (metrics_reprs )
@@ -185,12 +206,12 @@ def save_metrics(metrics_dir, metrics_list):
185
206
nll_value , \
186
207
cutoffs , cms , accs , uncs , ious , freqs , \
187
208
topk_value , brier_value , \
188
- count_bin , acc_bin , conf_bin , ece_value = metrics
209
+ count_bin , acc_bin , conf_bin , ece_value , ecse_value = metrics
189
210
190
211
metrics_acc .append ([
191
212
* keys ,
192
213
nll_value , * cutoffs , * accs , * uncs , * ious , * freqs ,
193
- topk_value , brier_value , ece_value
214
+ topk_value , brier_value , ece_value , ecse_value
194
215
])
195
216
196
217
save_lists (metrics_dir , metrics_acc )
@@ -270,7 +291,7 @@ def caccs(cm):
270
291
if float (np .sum (cm , axis = 1 )[ii ]) == 0 :
271
292
acc = 0.0
272
293
else :
273
- acc = np .diag (cm )[ii ] / float (np .sum (cm , axis = 1 )[ii ])
294
+ acc = np .diag (cm )[ii ] / ( float (np .sum (cm , axis = 1 )[ii ]) + 1e-7 )
274
295
accs .append (acc )
275
296
return accs
276
297
@@ -282,27 +303,36 @@ def unconfidence(cm_certain, cm_uncertain):
282
303
inaccurate_certain = np .sum (cm_certain ) - np .diag (cm_certain ).sum ()
283
304
inaccurate_uncertain = np .sum (cm_uncertain ) - np .diag (cm_uncertain ).sum ()
284
305
285
- return inaccurate_uncertain / (inaccurate_certain + inaccurate_uncertain )
306
+ return inaccurate_uncertain / (inaccurate_certain + inaccurate_uncertain + 1e-7 )
286
307
287
308
288
309
def frequency(cm_certain, cm_uncertain):
    """Return the share of all counted entries that belong to ``cm_certain``.

    Both arguments are summed over every element with ``np.sum``; the result
    is ``sum(cm_certain) / (sum(cm_certain) + sum(cm_uncertain))``.

    Args:
        cm_certain: Array-like of counts (presumably the confusion matrix of
            the confident predictions -- confirm against the caller).
        cm_uncertain: Array-like of counts for the remaining predictions.

    Returns:
        A scalar ratio. When both inputs sum to zero the result is 0.0
        rather than NaN, because of the epsilon in the denominator.
    """
    certain_total = np.sum(cm_certain)
    uncertain_total = np.sum(cm_uncertain)
    # The 1e-7 term keeps the division finite when nothing has been counted.
    return certain_total / (certain_total + uncertain_total + 1e-7)
290
311
291
312
292
313
def ece(count_bin, acc_bin, conf_bin):
    """Return the count-weighted mean absolute gap between accuracy and confidence.

    Computes ``sum(|acc_bin - conf_bin| * freq)`` where ``freq`` is each bin's
    count divided by the total count. This value is reported under the
    "ECE" label elsewhere in this file.

    Args:
        count_bin: Per-bin sample counts (sequence or array; presumably 1-D).
        acc_bin: Per-bin accuracy values, same length as ``count_bin``.
        conf_bin: Per-bin confidence values, same length as ``count_bin``.

    Returns:
        A scalar; 0.0 when every bin is empty.
    """
    count_bin = np.array(count_bin)
    acc_bin = np.array(acc_bin)
    conf_bin = np.array(conf_bin)
    # The epsilon prevents 0/0 when no samples were binned; nan_to_num is kept
    # as a second guard in case the counts themselves contain NaN.
    freq = np.nan_to_num(count_bin / (sum(count_bin) + 1e-7))
    ece_result = np.sum(np.absolute(acc_bin - conf_bin) * freq)
    return ece_result
299
320
300
321
322
def ecse(count_bin, acc_bin, conf_bin):
    """Return the count-weighted mean signed gap between confidence and accuracy.

    Computes ``sum((conf_bin - acc_bin) * freq)`` where ``freq`` is each bin's
    count divided by the total count. Unlike ``ece`` no absolute value is
    taken, so the result is positive when confidence exceeds accuracy on
    balance and negative when it falls short. This value is reported under
    the "ECE±" label elsewhere in this file.

    Args:
        count_bin: Per-bin sample counts (sequence or array; presumably 1-D).
        acc_bin: Per-bin accuracy values, same length as ``count_bin``.
        conf_bin: Per-bin confidence values, same length as ``count_bin``.

    Returns:
        A scalar; 0.0 when every bin is empty.
    """
    count_bin = np.array(count_bin)
    acc_bin = np.array(acc_bin)
    conf_bin = np.array(conf_bin)
    # The epsilon prevents 0/0 when no samples were binned.
    freq = np.nan_to_num(count_bin / (sum(count_bin) + 1e-7))
    ecse_result = np.sum((conf_bin - acc_bin) * freq)
    return ecse_result
329
+
330
+
301
331
def confidence_histogram (ax , count_bin ):
302
332
color , alpha = "tab:green" , 0.8
303
333
centers = np .linspace (0.05 , 0.95 , 10 )
304
334
count_bin = np .array (count_bin )
305
- freq = count_bin / sum (count_bin )
335
+ freq = count_bin / ( sum (count_bin ) + 1e-7 )
306
336
307
337
ax .bar (centers * 100 , freq * 100 , width = 10 , color = color , edgecolor = "black" , alpha = alpha )
308
338
ax .set_xlim (0 , 100.0 )
@@ -322,9 +352,9 @@ def reliability_diagram(ax, accs_bins, colors="tab:red", mode=0):
322
352
323
353
ax .plot (guides_x * 100 , guides_y * 100 , linestyle = guideline_style , color = "black" )
324
354
for accs_bin , color in zip (accs_bins , colors ):
325
- if mode is 0 :
355
+ if mode == 0 :
326
356
ax .bar (centers * 100 , accs_bin * 100 , width = 10 , color = color , edgecolor = "black" , alpha = alpha )
327
- elif mode is 1 :
357
+ elif mode == 1 :
328
358
ax .plot (centers * 100 , accs_bin * 100 , color = color , marker = "o" , alpha = alpha )
329
359
else :
330
360
raise ValueError ("Invalid mode %d." % mode )
0 commit comments