Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d0f723f

Browse files
committedNov 6, 2021
Add strong data augmentation
1 parent 044e732 commit d0f723f

File tree

1 file changed

+44
-14
lines changed

1 file changed

+44
-14
lines changed
 

‎ops/tests.py

+44-14
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@
1111
import torch.nn.functional as F
1212
import torchvision.transforms as transforms
1313

14+
from timm.loss import SoftTargetCrossEntropy, LabelSmoothingCrossEntropy
15+
1416
import ops.meters as meters
1517

1618

1719
@torch.no_grad()
1820
def test(model, n_ff, dataset,
19-
cutoffs=(0.0, 0.9), bins=np.linspace(0.0, 1.0, 11), verbose=False, period=10, gpu=True):
21+
transform=None, smoothing=0.0,
22+
cutoffs=(0.0, 0.9), bins=np.linspace(0.0, 1.0, 11),
23+
verbose=False, period=10, gpu=True):
2024
model.eval()
2125
model = model.cuda() if gpu else model.cpu()
2226
xs, ys = next(iter(dataset))
@@ -40,15 +44,30 @@ def test(model, n_ff, dataset,
4044
xs = xs.cuda()
4145
ys = ys.cuda()
4246

47+
if transform is not None:
48+
xs, ys_t = transform(xs, ys)
49+
else:
50+
xs, ys_t = xs, ys
51+
52+
if len(ys_t.shape) > 1:
53+
loss_function = SoftTargetCrossEntropy()
54+
ys = torch.max(ys_t, dim=-1)[1]
55+
elif smoothing > 0.0:
56+
loss_function = LabelSmoothingCrossEntropy(smoothing=smoothing)
57+
else:
58+
loss_function = nn.CrossEntropyLoss()
59+
loss_function = loss_function.cuda() if gpu else loss_function
60+
4361
# A. Predict results
4462
ys_pred = torch.stack([F.softmax(model(xs), dim=1) for _ in range(n_ff)])
4563
ys_pred = torch.mean(ys_pred, dim=0)
4664

65+
ys_t = ys_t.cpu()
4766
ys = ys.cpu()
4867
ys_pred = ys_pred.cpu()
4968

5069
# B. Measure Confusion Matrices
51-
nll_meter.update(F.nll_loss(torch.log(ys_pred), ys, reduction="none").numpy())
70+
nll_meter.update(loss_function(torch.log(ys_pred), ys_t).numpy())
5271
topk_meter.update(topk(ys.numpy(), ys_pred.numpy()))
5372
brier_meter.update(brier(ys.numpy(), ys_pred.numpy()))
5473

@@ -74,12 +93,13 @@ def test(model, n_ff, dataset,
7493
acc_bin = [gacc(cm_bin) for cm_bin in cms_bin]
7594
conf_bin = [conf_acc / (count + 1e-7) for count, conf_acc in zip(count_bin, conf_acc_bin)]
7695
ece_value = ece(count_bin, acc_bin, conf_bin)
96+
ecse_value = ecse(count_bin, acc_bin, conf_bin)
7797

7898
metrics = nll_value, \
7999
cutoffs, cms, accs, uncs, ious, freqs, \
80100
topk_value, brier_value, \
81-
count_bin, acc_bin, conf_bin, ece_value
82-
if verbose and int(step + 1) % period is 0:
101+
count_bin, acc_bin, conf_bin, ece_value, ecse_value
102+
if verbose and int(step + 1) % period == 0:
83103
print("%d Steps, %s" % (int(step + 1), repr_metrics(metrics)))
84104

85105
print(repr_metrics(metrics))
@@ -99,7 +119,7 @@ def repr_metrics(metrics):
99119
nll_value, \
100120
cutoffs, cms, accs, uncs, ious, freqs, \
101121
topk_value, brier_value, \
102-
count_bin, acc_bin, conf_bin, ece_value = metrics
122+
count_bin, acc_bin, conf_bin, ece_value, ecse_value = metrics
103123

104124
metrics_reprs = [
105125
"NLL: %.4f" % nll_value,
@@ -111,6 +131,7 @@ def repr_metrics(metrics):
111131
"Top-5: " + "%.3f %%" % (topk_value * 100),
112132
"Brier: " + "%.3f" % brier_value,
113133
"ECE: " + "%.3f %%" % (ece_value * 100),
134+
"ECE±: " + "%.3f %%" % (ecse_value * 100),
114135
]
115136

116137
return ", ".join(metrics_reprs)
@@ -185,12 +206,12 @@ def save_metrics(metrics_dir, metrics_list):
185206
nll_value, \
186207
cutoffs, cms, accs, uncs, ious, freqs, \
187208
topk_value, brier_value, \
188-
count_bin, acc_bin, conf_bin, ece_value = metrics
209+
count_bin, acc_bin, conf_bin, ece_value, ecse_value = metrics
189210

190211
metrics_acc.append([
191212
*keys,
192213
nll_value, *cutoffs, *accs, *uncs, *ious, *freqs,
193-
topk_value, brier_value, ece_value
214+
topk_value, brier_value, ece_value, ecse_value
194215
])
195216

196217
save_lists(metrics_dir, metrics_acc)
@@ -270,7 +291,7 @@ def caccs(cm):
270291
if float(np.sum(cm, axis=1)[ii]) == 0:
271292
acc = 0.0
272293
else:
273-
acc = np.diag(cm)[ii] / float(np.sum(cm, axis=1)[ii])
294+
acc = np.diag(cm)[ii] / (float(np.sum(cm, axis=1)[ii]) + 1e-7)
274295
accs.append(acc)
275296
return accs
276297

@@ -282,27 +303,36 @@ def unconfidence(cm_certain, cm_uncertain):
282303
inaccurate_certain = np.sum(cm_certain) - np.diag(cm_certain).sum()
283304
inaccurate_uncertain = np.sum(cm_uncertain) - np.diag(cm_uncertain).sum()
284305

285-
return inaccurate_uncertain / (inaccurate_certain + inaccurate_uncertain)
306+
return inaccurate_uncertain / (inaccurate_certain + inaccurate_uncertain + 1e-7)
286307

287308

288309
def frequency(cm_certain, cm_uncertain):
289-
return np.sum(cm_certain) / (np.sum(cm_certain) + np.sum(cm_uncertain))
310+
return np.sum(cm_certain) / (np.sum(cm_certain) + np.sum(cm_uncertain) + 1e-7)
290311

291312

292313
def ece(count_bin, acc_bin, conf_bin):
293314
count_bin = np.array(count_bin)
294315
acc_bin = np.array(acc_bin)
295316
conf_bin = np.array(conf_bin)
296-
freq = np.nan_to_num(count_bin / sum(count_bin))
317+
freq = np.nan_to_num(count_bin / (sum(count_bin) + 1e-7))
297318
ece_result = np.sum(np.absolute(acc_bin - conf_bin) * freq)
298319
return ece_result
299320

300321

322+
def ecse(count_bin, acc_bin, conf_bin):
323+
count_bin = np.array(count_bin)
324+
acc_bin = np.array(acc_bin)
325+
conf_bin = np.array(conf_bin)
326+
freq = np.nan_to_num(count_bin / (sum(count_bin) + 1e-7))
327+
ecse_result = np.sum((conf_bin - acc_bin) * freq)
328+
return ecse_result
329+
330+
301331
def confidence_histogram(ax, count_bin):
302332
color, alpha = "tab:green", 0.8
303333
centers = np.linspace(0.05, 0.95, 10)
304334
count_bin = np.array(count_bin)
305-
freq = count_bin / sum(count_bin)
335+
freq = count_bin / (sum(count_bin) + 1e-7)
306336

307337
ax.bar(centers * 100, freq * 100, width=10, color=color, edgecolor="black", alpha=alpha)
308338
ax.set_xlim(0, 100.0)
@@ -322,9 +352,9 @@ def reliability_diagram(ax, accs_bins, colors="tab:red", mode=0):
322352

323353
ax.plot(guides_x * 100, guides_y * 100, linestyle=guideline_style, color="black")
324354
for accs_bin, color in zip(accs_bins, colors):
325-
if mode is 0:
355+
if mode == 0:
326356
ax.bar(centers * 100, accs_bin * 100, width=10, color=color, edgecolor="black", alpha=alpha)
327-
elif mode is 1:
357+
elif mode == 1:
328358
ax.plot(centers * 100, accs_bin * 100, color=color, marker="o", alpha=alpha)
329359
else:
330360
raise ValueError("Invalid mode %d." % mode)

0 commit comments

Comments
 (0)
Please sign in to comment.