-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
262 lines (191 loc) · 8.76 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import os
import math
import argparse
import torch
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
from torch.nn import DataParallel as DP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
from datetime import datetime
from sysbinder import SysBinderImageAutoEncoder
from data import GlobDataset
from utils import linear_warmup, cosine_anneal
# Command-line hyperparameters for a training run. The parsed namespace is
# later handed as a whole to the model constructor, so model-related options
# below are consumed there rather than in this script.
parser = argparse.ArgumentParser()

# Reproducibility and data loading.
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=40)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--image_size', type=int, default=128)
parser.add_argument('--image_channels', type=int, default=3)

# Input/output locations. `data_path` is a glob pattern.
parser.add_argument('--checkpoint_path', default='checkpoint.pt.tar')
parser.add_argument('--data_path', default='data/*.png')
parser.add_argument('--log_path', default='logs/')

# Optimisation: per-group peak learning rates, linear warm-up length,
# exponential-decay half-life (in steps) and the gradient clipping threshold.
parser.add_argument('--lr_dvae', type=float, default=3e-4)
parser.add_argument('--lr_enc', type=float, default=1e-4)
parser.add_argument('--lr_dec', type=float, default=3e-4)
parser.add_argument('--lr_warmup_steps', type=int, default=30000)
parser.add_argument('--lr_half_life', type=int, default=250000)
parser.add_argument('--clip', type=float, default=0.05)
parser.add_argument('--epochs', type=int, default=500)

# Model hyperparameters (read by the model from the namespace; `num_slots`
# is additionally used by this script when laying out visualisation grids).
parser.add_argument('--num_iterations', type=int, default=3)
parser.add_argument('--num_slots', type=int, default=4)
parser.add_argument('--num_blocks', type=int, default=8)
parser.add_argument('--cnn_hidden_size', type=int, default=512)
parser.add_argument('--slot_size', type=int, default=2048)
parser.add_argument('--mlp_hidden_size', type=int, default=192)
parser.add_argument('--num_prototypes', type=int, default=64)

parser.add_argument('--vocab_size', type=int, default=4096)
parser.add_argument('--num_decoder_layers', type=int, default=8)
parser.add_argument('--num_decoder_heads', type=int, default=4)
parser.add_argument('--d_model', type=int, default=192)
# Fix: the default is 0.1, so this option must parse as a float. With
# `type=int`, any explicit value such as `--dropout 0.2` was rejected.
parser.add_argument('--dropout', type=float, default=0.1)

# Temperature schedule: cosine-annealed from `tau_start` to `tau_final`
# over `tau_steps` steps.
parser.add_argument('--tau_start', type=float, default=1.0)
parser.add_argument('--tau_final', type=float, default=0.1)
parser.add_argument('--tau_steps', type=int, default=30000)

# DataParallel is on by default. `store_true` combined with `default=True`
# can never produce False, so `--use_dp` is kept for compatibility (it is a
# no-op) and `--no_dp` is the switch that actually disables it.
parser.add_argument('--use_dp', default=True, action='store_true')
parser.add_argument('--no_dp', dest='use_dp', action='store_false')
# Parse the command line and seed PyTorch's RNG before anything else is built.
args = parser.parse_args()
torch.manual_seed(args.seed)

# Every run logs into its own sub-directory named after the start timestamp.
log_dir = os.path.join(args.log_path, datetime.today().isoformat())
writer = SummaryWriter(log_dir)

# Record the full configuration as a single "key=value__key=value" string.
arg_str_list = [f'{key}={value}' for key, value in vars(args).items()]
arg_str = '__'.join(arg_str_list)
writer.add_text('hparams', arg_str)
def visualize(image, recon_dvae, recon_tf, attns, N=8):
    """Build one image grid comparing inputs, reconstructions and attention maps.

    The first ``N`` entries of ``image``, ``recon_dvae`` and ``recon_tf`` each
    gain a new axis at position 1 and are concatenated with the first ``N``
    entries of ``attns`` along that axis; the first two axes are then flattened
    so the result is a flat batch of tiles for ``make_grid``.

    Args:
        image: Input batch, indexed here as a 4-D tensor.
        recon_dvae: dVAE reconstruction, indexed like ``image``.
        recon_tf: Autoregressive reconstruction, indexed like ``image``.
        attns: Attention visualisations, indexed here as a 5-D tensor.
            Presumably its second axis has ``args.num_slots`` entries so that
            each sample fills exactly one grid row -- TODO confirm.
        N: Maximum number of samples taken from the front of the batch.

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid`` with
        ``3 + args.num_slots`` tiles per row and a padding value of 0.8.
    """
    # Input and both reconstructions each contribute a single tile per sample.
    per_sample = [
        batch[:N, None, :, :, :]
        for batch in (image, recon_dvae, recon_tf)
    ]
    # The attention tensor already carries its own per-sample tile axis.
    per_sample.append(attns[:N, :, :, :, :])

    tiles = torch.cat(per_sample, dim=1).flatten(end_dim=1)

    tiles_per_row = 3 + args.num_slots
    return vutils.make_grid(tiles, nrow=tiles_per_row, pad_value=0.8)
# Both datasets are built from the same glob pattern; the `phase` argument is
# what distinguishes them (presumably selecting the split -- see GlobDataset).
train_dataset = GlobDataset(root=args.data_path, phase='train', img_size=args.image_size)
val_dataset = GlobDataset(root=args.data_path, phase='val', img_size=args.image_size)

train_sampler = None
val_sampler = None

# Loader settings shared by the training and validation loaders.
# NOTE(review): because they are shared, the validation loader also shuffles
# and drops its last partial batch. A validation set smaller than one batch
# therefore yields zero batches, and the per-epoch validation average then
# divides by zero -- confirm this is intended before changing it.
loader_kwargs = {
    'batch_size': args.batch_size,
    'shuffle': True,
    'num_workers': args.num_workers,
    'pin_memory': True,
    'drop_last': True,
}

train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_kwargs)
val_loader = DataLoader(val_dataset, sampler=val_sampler, **loader_kwargs)

# Number of batches per epoch for each loader.
train_epoch_size = len(train_loader)
val_epoch_size = len(val_loader)

# Log about five times per epoch. Clamp to 1: with fewer than five training
# batches the floor division gives 0, and the `batch % log_interval` check in
# the training loop would raise ZeroDivisionError.
log_interval = max(1, train_epoch_size // 5)
# Build the model from the full argument namespace.
model = SysBinderImageAutoEncoder(args)

# Fresh-run defaults; replaced below when a checkpoint file is present.
checkpoint = None
start_epoch = 0
best_val_loss = math.inf
best_epoch = 0

if os.path.isfile(args.checkpoint_path):
    # Load onto the CPU first; the model is moved to the GPU afterwards.
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    start_epoch = checkpoint['epoch']
    best_val_loss = checkpoint['best_val_loss']
    best_epoch = checkpoint['best_epoch']
    model.load_state_dict(checkpoint['model'])

model = model.cuda()
if args.use_dp:
    model = DP(model)

# One parameter group per sub-module, selected by substring match on the
# parameter name. The encoder/decoder groups start at lr 0.0; the training
# loop overwrites every group's learning rate on each step.
group_specs = (
    ('dvae', args.lr_dvae),
    ('image_encoder', 0.0),
    ('image_decoder', 0.0),
)
optimizer = Adam([
    {
        'params': [param for name, param in model.named_parameters() if tag in name],
        'lr': initial_lr,
    }
    for tag, initial_lr in group_specs
])

# Restore optimizer state only when resuming from a checkpoint.
if checkpoint is not None:
    optimizer.load_state_dict(checkpoint['optimizer'])
# With DataParallel the underlying network lives under `.module`; that
# unwrapped network is what gets saved and what runs autoregressive
# reconstruction for the visualisations.
net = model.module if args.use_dp else model

for epoch in range(start_epoch, args.epochs):
    epoch_no = epoch + 1  # 1-based epoch number used for logging/checkpoints
    model.train()

    for batch_idx, image in enumerate(train_loader):
        global_step = epoch * train_epoch_size + batch_idx

        # Schedules, all driven by the global step.
        tau = cosine_anneal(
            global_step, args.tau_start, args.tau_final, 0, args.tau_steps)
        lr_warmup_factor_enc = linear_warmup(
            global_step, 0., 1.0, 0., args.lr_warmup_steps)
        lr_warmup_factor_dec = linear_warmup(
            global_step, 0., 1.0, 0, args.lr_warmup_steps)
        # Exponential decay that halves every `lr_half_life` steps.
        lr_decay_factor = math.exp(global_step / args.lr_half_life * math.log(0.5))

        # Group 0 (dvae) stays at a constant rate; groups 1 and 2 (encoder and
        # decoder) follow warm-up times decay.
        new_lrs = (
            args.lr_dvae,
            lr_decay_factor * lr_warmup_factor_enc * args.lr_enc,
            lr_decay_factor * lr_warmup_factor_dec * args.lr_dec,
        )
        for group_idx, group_lr in enumerate(new_lrs):
            optimizer.param_groups[group_idx]['lr'] = group_lr

        image = image.cuda()

        optimizer.zero_grad()
        (recon_dvae, cross_entropy, mse, attns) = model(image, tau)

        # The DataParallel outputs are reduced to scalars here by averaging.
        if args.use_dp:
            mse = mse.mean()
            cross_entropy = cross_entropy.mean()

        loss = mse + cross_entropy
        loss.backward()
        clip_grad_norm_(model.parameters(), args.clip, 'inf')
        optimizer.step()

        with torch.no_grad():
            if batch_idx % log_interval == 0:
                loss_value = loss.item()
                mse_value = mse.item()
                print('Train Epoch: {:3} [{:5}/{:5}] \t Loss: {:F} \t MSE: {:F}'.format(
                    epoch_no, batch_idx, train_epoch_size, loss_value, mse_value))

                train_scalars = {
                    'TRAIN/loss': loss_value,
                    'TRAIN/cross_entropy': cross_entropy.item(),
                    'TRAIN/mse': mse_value,
                    'TRAIN/tau': tau,
                    'TRAIN/lr_dvae': optimizer.param_groups[0]['lr'],
                    'TRAIN/lr_enc': optimizer.param_groups[1]['lr'],
                    'TRAIN/lr_dec': optimizer.param_groups[2]['lr'],
                }
                for tag, value in train_scalars.items():
                    writer.add_scalar(tag, value, global_step)

    with torch.no_grad():
        # Visualise the last training batch of the epoch (the model is still
        # in train mode at this point).
        recon_tf = net.reconstruct_autoregressive(image[:8])
        train_grid = visualize(image, recon_dvae, recon_tf, attns, N=8)
        writer.add_image('TRAIN_recons/epoch={:03}'.format(epoch_no), train_grid)

        # ---- validation ----
        model.eval()

        cross_entropy_total = 0.
        mse_total = 0.

        for image in val_loader:
            image = image.cuda()

            # `tau` is the value left over from the final training step.
            (recon_dvae, cross_entropy, mse, attns) = model(image, tau)

            if args.use_dp:
                mse = mse.mean()
                cross_entropy = cross_entropy.mean()

            cross_entropy_total += cross_entropy.item()
            mse_total += mse.item()

        # Per-batch averages over the validation epoch.
        val_cross_entropy = cross_entropy_total / val_epoch_size
        val_mse = mse_total / val_epoch_size
        val_loss = val_mse + val_cross_entropy

        for tag, value in (
            ('VAL/loss', val_loss),
            ('VAL/cross_entropy', val_cross_entropy),
            ('VAL/mse', val_mse),
        ):
            writer.add_scalar(tag, value, epoch_no)

        print('====> Epoch: {:3} \t Loss = {:F}'.format(epoch_no, val_loss))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch_no

            torch.save(net.state_dict(), os.path.join(log_dir, 'best_model.pt'))

            # Validation images are only written for new-best epochs from the
            # 51st epoch on; they use the last validation batch.
            # NOTE(review): nesting under the new-best branch was reconstructed
            # from a listing with its indentation stripped -- confirm.
            if 50 <= epoch:
                recon_tf = net.reconstruct_autoregressive(image[:8])
                val_grid = visualize(image, recon_dvae, recon_tf, attns, N=8)
                writer.add_image('VAL_recons/epoch={:03}'.format(epoch_no), val_grid)

        writer.add_scalar('VAL/best_loss', best_val_loss, epoch_no)

        # Resumable checkpoint; 'epoch' stores the next epoch to run.
        checkpoint = {
            'epoch': epoch_no,
            'best_val_loss': best_val_loss,
            'best_epoch': best_epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(checkpoint, os.path.join(log_dir, 'checkpoint.pt.tar'))

        print('====> Best Loss = {:F} @ Epoch {}'.format(best_val_loss, best_epoch))

writer.close()