-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain_net.py
98 lines (80 loc) · 3.13 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Written by Willy, Weiyuan
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.lib.opt.train import TrainModel
#from src.lib.opt.train_adv import TrainModel
from src.lib.network.cn import vgg16
import yaml
from easydict import EasyDict as edict
import sys,getopt
def initGenFeatFromCfg(cfg_file):
    """Load the training parameters of the selected dataset from a YAML file.

    The top-level ``DATASET`` entry names the section of the file that holds
    the actual settings; every other value is read from that section.

    Args:
        cfg_file: Path of the YAML configuration file to read.

    Returns:
        A 13-tuple ``(dataset, data_path, tensor_server_path,
        pre_trained_path, batch_size, lr, epoch_num, steps, decay_rate,
        start_epoch, snap_shot, resize, val_size)``. ``lr`` is converted to
        ``float``; every other value is returned exactly as parsed.
    """
    # Load cfg parameters from the yaml file.
    with open(cfg_file, 'r') as f:
        # yaml.load() without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError from PyYAML 6.0 on; safe_load()
        # works on every version and only builds plain Python objects.
        cfg = edict(yaml.safe_load(f))

    # First load the dataset name, then select its section once.
    dataset = cfg.DATASET
    ds_cfg = cfg[dataset]

    # Paths
    data_path = ds_cfg.DATA_PATH
    tensor_server_path = ds_cfg.TENSOR_BOARD_PATH
    pre_trained_path = ds_cfg.PRE_TRAINED_PATH

    # Network / optimisation settings
    batch_size = ds_cfg.BATCH_SIZE
    # NOTE(review): explicit float() presumably guards against YAML parsing
    # values such as "1e-5" as strings -- confirm against the config files.
    lr = float(ds_cfg.LEARNING_RATE)
    epoch_num = ds_cfg.EPOCH_NUM
    steps = ds_cfg.STEPS
    decay_rate = ds_cfg.DECAY_RATE
    start_epoch = ds_cfg.START_EPOCH
    snap_shot = ds_cfg.SNAP_SHOT
    resize = ds_cfg.RESIZE
    val_size = ds_cfg.VAL_SIZE

    return (dataset, data_path, tensor_server_path, pre_trained_path,
            batch_size, lr, epoch_num, steps, decay_rate, start_epoch,
            snap_shot, resize, val_size)
def dispHelp():
    """Print the command-line usage summary to standard output."""
    rule = "======================================================"
    help_lines = (
        rule,
        " Usage",
        rule,
        "\t-h display this message",
        "\t--cfg <config file yaml>",
    )
    for line in help_lines:
        print(line)
def main(argv):
    """Parse the command line, load the configuration and run the training.

    Recognised options:
        -h          print the usage message and return
        --cfg FILE  YAML configuration file (default 'model/shanghaitech.yml')

    Args:
        argv: Command-line arguments without the program name
            (the script entry point passes ``sys.argv[1:]``).

    Returns:
        None. Returns early, after printing the usage message, when ``-h``
        is given or when the options cannot be parsed.
    """
    cfg_file = 'model/shanghaitech.yml'

    # Get parameters
    try:
        # "h" without a trailing colon: -h is a plain flag. The previous
        # "h:" spec made getopt demand an argument for -h, so a bare "-h"
        # was reported as a parse error instead of being handled below.
        opts, _ = getopt.getopt(argv, "h", ["cfg="])
    except getopt.GetoptError:
        dispHelp()
        return

    for opt, arg in opts:
        if opt == '-h':
            # dispHelp takes no parameters; passing argv[0] raised TypeError.
            dispHelp()
            return
        # Exact comparison: `opt in ("--cfg")` was a substring test on a
        # string, not membership in a tuple.
        if opt == "--cfg":
            cfg_file = arg

    print("Loading configuration file: ", cfg_file)
    (dataset, data_path, tensor_server_path, pre_trained_path, batch_size,
     lr, epoch_num, steps, decay_rate, start_epoch, snap_shot, resize,
     val_size) = initGenFeatFromCfg(cfg_file)

    # Echo the chosen parameters; labels are kept verbatim so the console
    # output is unchanged.
    print("Choosen parameters:")
    print("-------------------")
    summary = (
        ("Dataset: ", dataset),
        ("Data location: ", data_path),
        ("Tensorboard server root: ", tensor_server_path),
        ("Pre-trained model path:", pre_trained_path),
        ("Batch size:", batch_size),
        ("Learning rate: ", lr),
        ("Total epoch number: ", epoch_num),
        ("Learning rate steps: ", steps),
        ("Learning rate decay rate: ", decay_rate),
        ("Start epoch: ", start_epoch),
        ("Snap shot: ", snap_shot),
        ("Image resize: ", resize),
        ("Validation size: ", val_size),
    )
    for label, value in summary:
        print(label, value)
    print("")
    print("===================")
    print("")

    # Build the trainer with the vgg16 branch and start the training run.
    tm = TrainModel(data_path=data_path, batchsize=batch_size, lr=lr,
                    epoch=epoch_num, snap_shot=snap_shot,
                    server_root_path=tensor_server_path,
                    start_epoch=start_epoch, steps=steps,
                    decay_rate=decay_rate, branch=vgg16,
                    pre_trained=pre_trained_path, resize=resize,
                    val_size=val_size)
    tm.run()
# Script entry point: forward everything after the program name to main().
if __name__ == "__main__":
    main(sys.argv[1:])