-
Notifications
You must be signed in to change notification settings - Fork 12
/
train.py
49 lines (37 loc) · 1.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import sys
import signal
import argparse
from torchvision import transforms
from misc import util
from network import Builder, Trainer
from dataset import CelebA
def parse_args(argv=None):
    """Parse the command-line arguments of the Glow training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, matching the
            previous zero-argument behaviour.

    Returns:
        An ``argparse.Namespace`` whose ``profile`` attribute is the path to
        the profile file (``'profile/celeba.json'`` when none is given).
    """
    parser = argparse.ArgumentParser(
        description='PyTorch implementation of "Glow: Generative Flow with Invertible 1x1 Convolutions"')
    # argparse only applies ``default`` to a positional argument when that
    # argument is optional; without ``nargs='?'`` the positional is required
    # and the default below would never be used.
    parser.add_argument('profile', type=str, nargs='?',
                        default='profile/celeba.json',
                        help='path to profile file')
    return parser.parse_args(argv)
if __name__ == '__main__':
    def _quiet_exit(signum, frame):
        """SIGINT handler: end the process with exit status 0.

        Raising SystemExit here means Ctrl-C stops training without the
        KeyboardInterrupt traceback that would otherwise be printed.
        """
        sys.exit(0)

    signal.signal(signal.SIGINT, _quiet_exit)

    # Command line -> logging -> hyper-parameters from the profile file.
    cli_args = parse_args()
    util.init_output_logging()
    hparams = util.load_profile(cli_args.profile)
    util.manual_seed(hparams.ablation.seed)

    # Builder.build() returns a mapping that is splatted into Trainer's
    # keyword arguments below.
    graph_builder = Builder(hparams)
    graph_state = graph_builder.build()

    # CelebA images: center-crop to 160, resize to 64, convert to tensors.
    data_root = hparams.dataset.root
    preprocess = transforms.Compose((
        transforms.CenterCrop(160),
        transforms.Resize(64),
        transforms.ToTensor(),
    ))
    celeba = CelebA(root=data_root, transform=preprocess)

    trainer = Trainer(hps=hparams, dataset=celeba, **graph_state)
    trainer.train()