-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathmeters.py
126 lines (98 loc) · 3.45 KB
/
meters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
class ProgressMeter(object):
    """Prints one tab-separated progress line per call.

    Each line is the prefix, a ``[current/total]`` batch counter, and the
    ``str()`` of every meter passed to the constructor.
    """

    def __init__(self, num_batches, *meters, prefix=""):
        """Store the meters and prefix and pre-build the counter template.

        Args:
            num_batches: Total number of batches, shown as the counter's
                right-hand side and used to size its left-hand field.
            *meters: Objects whose ``str()`` is appended to each line.
            prefix: Text placed directly before the batch counter.
        """
        self.prefix = prefix
        self.meters = meters
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def print(self, batch):
        """Write the progress line for batch index ``batch`` to stdout."""
        counter = self.batch_fmtstr.format(batch)
        fields = [self.prefix + counter]
        fields.extend(str(m) for m in self.meters)
        print('\t'.join(fields))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` for ``num_batches``.

        The placeholder width equals the number of characters in the total,
        so the current index is right-aligned under it.
        """
        width = len(str(num_batches // 1))
        slot = '{{:{}d}}'.format(width)
        return '[{}/{}]'.format(slot, slot.format(num_batches))
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name='', fmt=':f'):
        """Create an empty meter.

        Args:
            name: Label shown at the start of the string form.
            fmt: Format spec (including the leading ``:``) applied to both
                the current value and the average in ``__str__``.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Args:
            val: Latest measurement (typically a per-batch mean).
            n: Number of samples ``val`` represents.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: a zero total count (e.g. an empty first batch
        # passed with n=0) would otherwise raise ZeroDivisionError. In that
        # case the average keeps its previous (reset) value.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` formatted with ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
class OnlineMeter(object):
    """Computes and stores the running mean and variance/std of a tensor.

    Statistics are accumulated element-wise, one ``update`` call at a time,
    using the incremental mean / sum-of-squared-deviations (``M2``)
    recurrence, so no history of past tensors is kept.
    """

    def __init__(self):
        """Create a meter with placeholder state.

        The real buffers are allocated on the first ``update`` call, when
        the shape and type of the tracked tensor become known.
        """
        self.val = None
        self.mean = torch.FloatTensor(1).fill_(-1)
        self.M2 = torch.FloatTensor(1).zero_()
        self.count = 0.
        self.needs_init = True

    def reset(self, x):
        """Zero the statistics, sizing the buffers like tensor ``x``."""
        self.mean = torch.zeros_like(x)
        self.M2 = torch.zeros_like(x)
        self.count = 0.
        self.needs_init = False

    def update(self, x):
        """Fold tensor ``x`` into the running mean and ``M2``.

        Args:
            x: Tensor with the same shape as earlier updates (the first
                call fixes the shape).
        """
        self.val = x
        if self.needs_init:
            self.reset(x)
        self.count += 1
        # delta uses the mean before this sample, delta2 the mean after it;
        # their product is the increment of the sum of squared deviations.
        delta = x - self.mean
        self.mean.add_(delta / self.count)
        delta2 = x - self.mean
        self.M2.add_(delta * delta2)

    @property
    def var(self):
        """Element-wise sample variance (zeros until two samples are seen)."""
        if self.count < 2:
            return self.M2.clone().zero_()
        return self.M2 / (self.count - 1)

    @property
    def std(self):
        """Element-wise sample standard deviation."""
        # ``var`` is a property, so it must be read, not called.
        return self.var.sqrt()
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: Score tensor; the top entries are taken along dimension 1.
        target: Tensor of true class indices, one per row of ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one 1-element float tensor per k, holding the
        percentage of rows whose target is among the k highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the largest_k best scores per row, best first, then
        # transposed so that row i holds every sample's i-th best guess.
        top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1]
        top_indices = top_indices.t()
        truth = target.view(1, -1).expand_as(top_indices)
        hits = top_indices.eq(truth)

        scale = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
class AccuracyMeter(object):
    """Computes and stores the average and current top-k accuracy.

    One ``AverageMeter`` is kept per k in ``topk``; the public properties
    expose their readings as dicts keyed by k.
    """

    def __init__(self, topk=(1,)):
        """Create a meter tracking accuracy for every k in ``topk``."""
        self.topk = topk
        self.reset()

    def reset(self):
        """Replace every per-k meter with a fresh ``AverageMeter``."""
        self._meters = {k: AverageMeter() for k in self.topk}

    def update(self, output, target):
        """Score one batch and fold the result into the per-k meters.

        Args:
            output: Score tensor passed through to ``accuracy``.
            target: Tensor of true class indices for the batch.
        """
        n = target.nelement()
        acc_vals = accuracy(output, target, self.topk)
        for k, acc in zip(self.topk, acc_vals):
            # Weight by the batch size so the running average is per-sample;
            # otherwise a smaller final batch counts as much as a full one.
            self._meters[k].update(acc, n)

    @property
    def val(self):
        """Most recent batch accuracy (percent) for each k."""
        return {n: meter.val for (n, meter) in self._meters.items()}

    @property
    def avg(self):
        """Running average accuracy (percent) for each k."""
        return {n: meter.avg for (n, meter) in self._meters.items()}

    @property
    def avg_error(self):
        """Running average error, ``100 - avg`` (percent), for each k."""
        return {n: 100. - meter.avg for (n, meter) in self._meters.items()}