-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathtrain_IGRs.py
161 lines (139 loc) · 6.14 KB
/
train_IGRs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
Training the coordinate localization sub-network.
Author: Shichao Li
Contact: nicholas.li@connect.ust.hk
"""
import sys
sys.path.append('../')
import torch
import os
import libs.arguments.parse as parse
import libs.logger.logger as liblogger
import libs.dataset as dataset
# import libs.dataset.ApolloScape.car_instance
import libs.dataset.KITTI.car_instance
import libs.trainer.trainer as trainer
import libs.model as models
import libs.optimizer.optimizer as optimizer
import libs.loss.function as loss_func
from libs.common.utils import get_model_summary
from libs.metric.criterions import get_distance_src, get_angle_error
from libs.metric.criterions import Evaluator
def choose_loss_func(model_settings, cfgs):
    """
    Initialize the loss function used for training.

    Args:
        model_settings: heat-map model settings dict ('loss_type',
            'loss_spec_list', 'loss_weight_list', 'input_size', ...).
        cfgs: full experiment configuration dict.

    Returns:
        The instantiated loss module, moved to GPU.

    Raises:
        AttributeError: if 'loss_type' does not name a class in loss_func.
    """
    loss_type = model_settings['loss_type']
    if loss_type == 'JointsCompositeLoss':
        spec_list = model_settings['loss_spec_list']
        loss_weights = model_settings['loss_weight_list']
        func = loss_func.JointsCompositeLoss(spec_list=spec_list,
                                             img_size=model_settings['input_size'],
                                             hm_size=model_settings['heatmap_size'],
                                             cr_loss_thres=model_settings['cr_loss_threshold'],
                                             loss_weights=loss_weights
                                             )
    else:
        # Resolve the class by attribute lookup instead of eval(): identical
        # behavior for valid names, but no arbitrary-code execution on
        # configuration-supplied strings.
        loss_cls = getattr(loss_func, loss_type)
        func = loss_cls(use_target_weight=cfgs['training_settings']['use_target_weight'])
    # The order of the points is needed when computing the cross-ratio loss.
    # Use .get so loss types configured without a spec list simply skip the
    # cross-ratio setup instead of raising KeyError.
    spec_list = model_settings.get('loss_spec_list', [None, None, 'None'])
    if spec_list[2] != 'None':
        func.cr_indices = libs.dataset.KITTI.car_instance.cr_indices_dict['bbox12']
        # target cross-ratio of the four collinear cuboid points
        func.target_cr = 4 / 3
    return func.cuda()
def train(model, model_settings, GPUs, cfgs, logger, final_output_dir):
    """
    Train the coordinate localization sub-network and save the final weights.

    Args:
        model: the heat-map regression network (un-wrapped).
        model_settings: heat-map model settings dict.
        GPUs: list of GPU ids used for DataParallel.
        cfgs: full experiment configuration dict.
        logger: logger instance.
        final_output_dir: directory where the final state dict is written.

    Raises:
        ValueError: if cfgs['exp_type'] is not a recognized experiment type.
    """
    # Log a model summary produced from a dummy forward-pass input.
    input_size = model_settings['input_size']
    # Two extra channels hold the (x, y) coordinate maps when 'add_xy' is set.
    input_channels = 5 if cfgs['heatmapModel']['add_xy'] else 3
    dump_input = torch.rand((1, input_channels, input_size[1], input_size[0]))
    logger.info(get_model_summary(model, dump_input))
    model = torch.nn.DataParallel(model, device_ids=GPUs).cuda()
    # specify loss function
    func = choose_loss_func(model_settings, cfgs)
    # Dataset preparation: resolve the dataset sub-module by attribute lookup
    # instead of eval() on a configuration-supplied string.
    data_cfgs = cfgs['dataset']
    dataset_module = getattr(dataset, data_cfgs['name']).car_instance
    train_dataset, valid_dataset = dataset_module.prepare_data(cfgs, logger)
    # get the optimizer and learning rate scheduler
    optim, sche = optimizer.prepare_optim(model, cfgs)
    # Metric used for reporting training error depends on the experiment type;
    # fail fast on an unknown type instead of hitting a NameError below.
    if cfgs['exp_type'] in ['baselinealpha', 'baselinetheta']:
        metric_function = get_angle_error
        save_debug_images = False
    elif cfgs['exp_type'] == 'instanceto2d':
        metric_function = get_distance_src
        save_debug_images = cfgs['training_settings']['debug']['save']
    else:
        raise ValueError("Unsupported exp_type: {}".format(cfgs['exp_type']))
    collate_fn = train_dataset.get_collate_fn()
    trainer.train(train_dataset=train_dataset,
                  valid_dataset=valid_dataset,
                  model=model,
                  loss_func=func,
                  optim=optim,
                  sche=sche,
                  metric_func=metric_function,
                  cfgs=cfgs,
                  logger=logger,
                  collate_fn=collate_fn,
                  save_debug=save_debug_images
                  )
    # Save the underlying module's weights so they load without DataParallel.
    final_model_state_file = os.path.join(final_output_dir, 'HC.pth')
    logger.info('=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    return
def evaluate(model, model_settings, GPUs, cfgs, logger, final_output_dir, eval_train=False):
    """
    Evaluate a trained model on the validation (and optionally training) split.

    Args:
        model: the heat-map regression network (un-wrapped).
        model_settings: heat-map model settings dict.
        GPUs: list of GPU ids used for DataParallel.
        cfgs: full experiment configuration dict.
        logger: logger instance.
        final_output_dir: output directory (unused here, kept for a signature
            parallel to train()).
        eval_train: also evaluate on the training split when True.
    """
    # Load the saved heat-map model weights from the configured path.
    saved_path = cfgs['dirs']['load_hm_model']
    model.load_state_dict(torch.load(saved_path))
    model = torch.nn.DataParallel(model, device_ids=GPUs).cuda()
    evaluator = Evaluator(cfgs['testing_settings']['eval_metrics'], cfgs)
    # Define the loss criterion. Named 'criterion' (not 'loss_func') so it
    # does not shadow the imported libs.loss.function module that
    # choose_loss_func relies on.
    criterion = choose_loss_func(model_settings, cfgs)
    # Dataset preparation: attribute lookup instead of eval() on config input.
    data_cfgs = cfgs['dataset']
    dataset_module = getattr(dataset, data_cfgs['name']).car_instance
    train_dataset, valid_dataset = dataset_module.prepare_data(cfgs, logger)
    collate_fn = valid_dataset.get_collate_fn()
    logger.info("Evaluation on the validation split:")
    trainer.evaluate(valid_dataset, model, criterion, cfgs, logger, evaluator, collate_fn=collate_fn)
    if eval_train:
        logger.info("Evaluation on the training split:")
        trainer.evaluate(train_dataset, model, criterion, cfgs, logger, evaluator, collate_fn=collate_fn)
    return
def main():
    """
    Entry point: parse configuration, set up GPU/cuDNN, build the model and
    dispatch to training or evaluation.

    Raises:
        RuntimeError: if GPU acceleration is unavailable or disabled — the
            pipeline calls .cuda() unconditionally, so fail fast with a clear
            message instead of a NameError further down.
    """
    # experiment configurations
    cfgs = parse.parse_args()
    # logging
    logger, final_output_dir = liblogger.get_logger(cfgs)
    # Set GPU
    if cfgs['use_gpu'] and torch.cuda.is_available():
        GPUs = cfgs['gpu_id']
    else:
        logger.info("GPU acceleration is disabled.")
        raise RuntimeError("This pipeline requires CUDA; enable 'use_gpu' "
                           "and run on a machine with a GPU.")
    if len(GPUs) == 1:
        torch.cuda.set_device(GPUs[0])
    # cudnn related setting
    torch.backends.cudnn.benchmark = cfgs['cudnn']['benchmark']
    torch.backends.cudnn.deterministic = cfgs['cudnn']['deterministic']
    torch.backends.cudnn.enabled = cfgs['cudnn']['enabled']
    # Model initialization: resolve the builder by attribute lookup instead of
    # eval() on a configuration-supplied model name.
    model_settings = cfgs['heatmapModel']
    model_name = model_settings['name']
    build_fn = getattr(models.heatmapModel, model_name).get_pose_net
    model = build_fn(cfgs, is_train=cfgs['train'])
    if cfgs['train']:
        train(model, model_settings, GPUs, cfgs, logger, final_output_dir)
    elif cfgs['evaluate']:
        evaluate(model, model_settings, GPUs, cfgs, logger, final_output_dir)
# Script entry point: run the experiment, then release cached GPU memory
# held by the CUDA allocator before the process exits.
if __name__ == '__main__':
    main()
    torch.cuda.empty_cache()