-
Notifications
You must be signed in to change notification settings - Fork 86
/
trainer.py
76 lines (55 loc) · 2.16 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import pdb
from tqdm import tqdm
from collections import defaultdict
import torch
import torch.nn as nn
class Trainer(nn.Module):
    """Helper class to train a deep network, one epoch per call.

    Subclass it and override `forward_backward` for your actual needs. That
    method receives one batch (already passed through `todevice`) and must
    return ``(loss, details)``. ``__call__`` only zeroes the gradients before
    it and steps the optimizer after it, so computing the gradients (e.g.
    ``loss.backward()``) is the subclass's job. ``details`` is a dict of
    per-iteration values to log. Its values are averaged and printed with
    ``:.3f``, so they should be Python numbers or 0-dim tensors. Include a
    ``'loss'`` key, because its epoch mean is what ``__call__`` returns.

    Usage:
        train = Trainer(net, loader, loss, optimizer)
        for epoch in range(n_epochs):
            train()
    """

    def __init__(self, net, loader, loss, optimizer):
        """
        Args:
            net: the network being trained (must expose ``parameters()`` and
                ``train()``).
            loader: iterable yielding batches (tensors, or dicts/lists/tuples
                of them).
            loss: loss callable, stored as ``self.loss_func`` for use by
                ``forward_backward``; this base class never calls it.
            optimizer: object exposing ``zero_grad()`` and ``step()``.
        """
        nn.Module.__init__(self)
        self.net = net
        self.loader = loader
        self.loss_func = loss
        self.optimizer = optimizer

    def _device(self):
        """Return the device holding the network's first parameter."""
        return next(self.net.parameters()).device

    def iscuda(self):
        """Return True when the network lives on any non-CPU device."""
        return self._device().type != 'cpu'

    def todevice(self, x):
        """Recursively move `x` to the network's device.

        Dicts keep their keys. Tuples and lists are both returned as lists.
        On a non-CPU network, tensors are made contiguous and copied with
        ``non_blocking=True``. On a CPU network, tensors are moved to CPU.
        Any non-tensor leaf (e.g. a file name or an int id) is returned
        unchanged.
        """
        if isinstance(x, dict):
            return {k: self.todevice(v) for k, v in x.items()}
        if isinstance(x, (tuple, list)):
            return [self.todevice(v) for v in x]
        if not torch.is_tensor(x):
            return x  # non-tensor payloads cannot (and need not) be moved
        if not self.iscuda():
            return x.cpu()
        # .to(device) targets the net's actual device (e.g. cuda:1), whereas
        # .cuda() would always pick the current default CUDA device.
        return x.contiguous().to(self._device(), non_blocking=True)

    def __call__(self):
        """Run one training epoch over `self.loader` and print a loss summary.

        Returns:
            The mean of ``details['loss']`` over all iterations. Returns NaN
            when the loader was empty or ``forward_backward`` never reported
            a ``'loss'`` entry.

        Raises:
            RuntimeError: if ``forward_backward`` returns a loss containing NaN.
        """
        self.net.train()
        stats = defaultdict(list)
        for inputs in tqdm(self.loader):
            inputs = self.todevice(inputs)

            # compute gradient and do model update
            self.optimizer.zero_grad()
            loss, details = self.forward_backward(inputs)
            # as_tensor(...).any() also accepts Python floats and non-scalar losses
            if torch.isnan(torch.as_tensor(loss)).any():
                raise RuntimeError('Loss is NaN')
            self.optimizer.step()

            for key, val in details.items():
                stats[key].append(val)

        self._print_summary(stats)
        # .get avoids inserting an empty 'loss' list into the defaultdict
        return self._mean(stats.get('loss', []))  # return average loss

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of `values`, or NaN when it is empty."""
        if not values:
            return float('nan')
        return sum(values) / len(values)

    def _print_summary(self, stats):
        """Print, for each logged quantity, three means.

        The means cover roughly the first 10% of iterations, the last 10%,
        and the whole epoch.
        """
        print(" Summary of losses during this epoch:")
        for loss_name, vals in stats.items():
            n = 1 + len(vals) // 10  # at least one sample at each end
            print(f" - {loss_name:20}:", end='')
            print(f" {self._mean(vals[:n]):.3f} --> {self._mean(vals[-n:]):.3f} (avg: {self._mean(vals):.3f})")

    def forward_backward(self, inputs):
        """Run the forward and backward pass on one batch; must be overridden.

        Args:
            inputs: one batch from the loader, already passed through `todevice`.

        Returns:
            ``(loss, details)``. ``loss`` is checked for NaN by ``__call__``.
            ``details`` is a dict of values to log; include a ``'loss'`` key.
        """
        raise NotImplementedError()