-
Notifications
You must be signed in to change notification settings - Fork 312
/
train.py
166 lines (129 loc) · 7.42 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""This script is the training script for Deep3DFaceRecon_pytorch
"""
import os
import time
import numpy as np
import torch
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import MyVisualizer
from util.util import genvalconf
import torch.multiprocessing as mp
import torch.distributed as dist
def setup(rank, world_size, port):
    """Join the default ``torch.distributed`` process group on this machine.

    Args:
        rank: index of this process within the group.
        world_size: total number of processes in the group.
        port: rendezvous port on localhost; any value whose ``str()`` is the
            port number (str or int).
    """
    # With no explicit init_method, init_process_group uses the env:// rendezvous,
    # which reads MASTER_ADDR / MASTER_PORT. os.environ only accepts str values,
    # so coerce the port in case the option was parsed as an int.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    # initialize the process group (gloo backend)
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the default process group if one was initialized.

    Safe to call unconditionally: it is a no-op when torch.distributed is
    unavailable or no process group has been set up.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def _run_validation(model, val_dataset, train_opt, val_opt, visualizer, rank, total_iters, epoch):
    """Run one pass over the validation set and return the averaged losses.

    The model is put into eval mode; the caller must call ``model.train()``
    afterwards.

    Args:
        model: model object returned by ``create_model``.
        val_dataset: iterable of validation batches returned by ``create_dataset``.
        train_opt: training options; reads ``vis_batch_nums``,
            ``eval_batch_nums`` and ``add_image``.
        val_opt: validation options; reads ``batch_size``.
        visualizer: ``MyVisualizer`` on rank 0, ``None`` on other ranks.
        rank: process rank; only rank 0 renders visuals.
        total_iters: current training position, forwarded to the visualizer.
        epoch: current epoch, forwarded to the visualizer.

    Returns:
        ``(losses_avg, eval_time)``. ``losses_avg`` maps each loss name to its mean
        over the first ``eval_batch_nums`` batches actually seen. ``eval_time`` is
        the wall-clock seconds spent in the pass.
    """
    with torch.no_grad():
        torch.cuda.synchronize()
        val_start_time = time.time()
        losses_sum = {}
        n_eval_batches = 0
        model.eval()
        # NOTE(review): every batch is run even past max(vis_batch_nums, eval_batch_nums),
        # although later batches contribute nothing here; consider breaking early.
        for j, val_data in enumerate(val_dataset):
            model.set_input(val_data)
            model.optimize_parameters(isTrain=False)
            if rank == 0 and j < train_opt.vis_batch_nums:
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch,
                                                   dataset='val', save_results=True,
                                                   count=j * val_opt.batch_size,
                                                   add_image=train_opt.add_image)
            if j < train_opt.eval_batch_nums:
                for key, value in model.get_current_losses().items():
                    losses_sum[key] = losses_sum.get(key, 0) + value
                n_eval_batches += 1
        # Divide by the number of batches actually accumulated. A count derived from
        # len(dataset) // batch_size is wrong with a partial last batch or a sharded
        # (per-rank) loader, and is 0 when the set is smaller than one batch.
        losses_avg = {key: value / n_eval_batches for key, value in losses_sum.items()}
        torch.cuda.synchronize()
        eval_time = time.time() - val_start_time
    return losses_avg, eval_time


def main(rank, world_size, train_opt):
    """Train the model on one process, with periodic validation, logging and checkpoints.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of processes (used only when ``train_opt.use_ddp``).
        train_opt: parsed training options.
    """
    val_opt = genvalconf(train_opt, isTrain=False)

    device = torch.device(rank)
    torch.cuda.set_device(device)
    use_ddp = train_opt.use_ddp

    if use_ddp:
        setup(rank, world_size, train_opt.ddp_port)

    train_dataset = create_dataset(train_opt, rank=rank)
    val_dataset = create_dataset(val_opt, rank=rank)
    train_dataset_batches = len(train_dataset) // train_opt.batch_size
    val_dataset_batches = len(val_dataset) // val_opt.batch_size

    model = create_model(train_opt)  # create a model given train_opt.model and other options
    model.setup(train_opt)
    model.device = device
    model.parallelize()

    # Only rank 0 logs and displays; other ranks keep None and never use it.
    visualizer = None
    if rank == 0:
        print('The batch number of training images = %d\n'
              'The batch number of validation images = %d' % (train_dataset_batches, val_dataset_batches))
        model.print_networks(train_opt.verbose)
        visualizer = MyVisualizer(train_opt)  # create a visualizer that display/save images and plots

    # total_iters advances by `batch_size` per training step. When resuming
    # (epoch_count > 1) it must start in the same unit. Otherwise it is
    # misaligned and the `% freq == 0` triggers below never fire.
    batch_size = 1 if train_opt.display_per_batch else train_opt.batch_size
    total_iters = train_dataset_batches * batch_size * (train_opt.epoch_count - 1)
    t_data = 0
    optimize_time = 0.1

    if use_ddp:
        dist.barrier()

    for epoch in range(train_opt.epoch_count, train_opt.n_epochs + 1):  # outer loop over epochs
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for train_data loading per iteration
        epoch_iter = 0  # progress within the current epoch, reset to 0 every epoch

        train_dataset.set_epoch(epoch)
        for train_data in train_dataset:  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if total_iters % train_opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            total_iters += batch_size
            epoch_iter += batch_size

            # Synchronize around the step so the timer measures completed GPU work.
            torch.cuda.synchronize()
            optimize_start_time = time.time()
            model.set_input(train_data)  # unpack train_data from dataset and apply preprocessing
            model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
            torch.cuda.synchronize()
            # Exponential moving average (new-sample weight 0.005) of step time / batch_size.
            optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time

            if use_ddp:
                dist.barrier()

            first_step = total_iters == batch_size

            # display images on visdom and save images to a HTML file
            if rank == 0 and (first_step or total_iters % train_opt.display_freq == 0):
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch,
                                                   save_results=True,
                                                   add_image=train_opt.add_image)

            # print training losses and save logging information to the disk
            if rank == 0 and (first_step or total_iters % train_opt.print_freq == 0):
                losses = model.get_current_losses()
                visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
                visualizer.plot_current_losses(total_iters, losses)

            if first_step or total_iters % train_opt.evaluation_freq == 0:
                losses_avg, eval_time = _run_validation(model, val_dataset, train_opt, val_opt,
                                                        visualizer, rank, total_iters, epoch)
                if rank == 0:  # report validation results
                    visualizer.print_current_losses(epoch, epoch_iter, losses_avg, eval_time, t_data, dataset='val')
                    visualizer.plot_current_losses(total_iters, losses_avg, dataset='val')
                model.train()

            if use_ddp:
                dist.barrier()

            # cache our latest model every <save_latest_freq> iterations
            if rank == 0 and (first_step or total_iters % train_opt.save_latest_freq == 0):
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                print(train_opt.name)  # it's useful to occasionally show the experiment name on console
                save_suffix = 'iter_%d' % total_iters if train_opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            if use_ddp:
                dist.barrier()

            iter_data_time = time.time()

        if rank == 0:
            print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.n_epochs, time.time() - epoch_start_time))
        model.update_learning_rate()  # update learning rates at the end of every epoch.

        # cache our model every <save_epoch_freq> epochs
        if rank == 0 and epoch % train_opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        if use_ddp:
            dist.barrier()

    # Release the process group created by setup().
    if use_ddp:
        cleanup()
if __name__ == '__main__':
    # Script entry point: parse options, then run training in one process or one process per rank.
    import warnings

    # Suppress every Python warning in this launcher process.
    warnings.filterwarnings("ignore")

    train_opt = TrainOptions().parse()  # get training options
    world_size = train_opt.world_size

    if not train_opt.use_ddp:
        # Single-process run on device 0.
        main(0, world_size, train_opt)
    else:
        # mp.spawn launches `nprocs` workers and prepends each worker's rank to `args`.
        mp.spawn(main, args=(world_size, train_opt), nprocs=world_size, join=True)