# main.py
import random
import sys
import traceback

import torch

import config
import utils
from checkpoints import Checkpoints
from dataloader import Dataloader
from model import Model
from test import Tester
from train import Trainer


def main():
    # parse the command-line arguments
    args = config.parse_args()

    # seed the RNGs for reproducibility
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)

    if args.save_results:
        utils.saveargs(args)

    # initialize the checkpoint handler
    checkpoints = Checkpoints(args)

    # create the model, loss criterion and evaluation metric
    models = Model(args)
    model, criterion, evaluation = models.setup(checkpoints)
    print('Model:\n\t{model}\nTotal params:\n\t{npar:.2f}M'.format(
        model=args.model_type,
        npar=sum(p.numel() for p in model.parameters()) / 1e6))

    # data loading
    dataloader = Dataloader(args)
    loaders = dataloader.create()

    # the trainer handles the training loop
    trainer = Trainer(args, model, criterion, evaluation)
    # the tester handles evaluation on the validation set
    tester = Tester(args, model, criterion, evaluation)

    # start training
    loss_best = 1e10
    for epoch in range(args.nepochs):
        print('\nEpoch %d/%d\n' % (epoch + 1, args.nepochs))

        # train and evaluate for a single epoch
        loss_train = trainer.train(epoch, loaders)
        loss_test = tester.test(epoch, loaders)

        # save a checkpoint whenever the validation loss improves
        if loss_best > loss_test:
            model_best = True
            loss_best = loss_test
            if args.save_results:
                checkpoints.save(epoch, model, model_best)
if __name__ == "__main__":
    utils.setup_graceful_exit()
    try:
        main()
    except (KeyboardInterrupt, SystemExit):
        # do not print a stack trace when Ctrl-C is pressed
        pass
    except Exception:
        traceback.print_exc(file=sys.stdout)
    finally:
        utils.cleanup()
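As an aside, the seeding in main() covers Python's random module and torch's CPU generator. Below is a minimal sketch of how that could be extended for GPU runs; it uses standard PyTorch calls but is not part of the original script, and the helper name seed_cuda is made up for illustration.

import torch

def seed_cuda(seed: int) -> None:
    """Optional extra seeding for GPU reproducibility (not in the original main.py)."""
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)       # seed every visible GPU
    torch.backends.cudnn.deterministic = True  # prefer deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False     # disable autotuning that can vary results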