-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathscript_run_train.py
92 lines (65 loc) · 2.36 KB
/
script_run_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import re
import subprocess

import numpy as np
import yaml
# Script launched by run_script (passed as its ``cmd``) under the Python interpreter.
CMD: str = "train_knn.py"
def run_script(**arg):
    """Launch a training script in a child Python process.

    Builds the command
    ``<python> -W ignore <cmd> [--model M] --sess_cfg S [--results R]
    [--plot P] [--pretrained CKPT]``, prints it as a banner, and runs it.
    The process is started from an argument list (no shell), so values
    containing spaces or shell metacharacters are passed through verbatim.

    Keyword Args:
        cmd: Path of the script to run (required).
        session: Value passed as ``--sess_cfg`` (required).
        model: Optional value passed as ``--model``.
        results: Optional value passed as ``--results``.
        plot: Optional value passed as ``--plot`` (converted with ``str``).
        pretrained: Optional checkpoint path *without* the ``.pth`` suffix.
            ``--pretrained`` is only added when ``<pretrained>.pth`` exists;
            otherwise it is silently skipped.
        python: Interpreter executable to launch (default ``'python.exe'``).

    Returns:
        The child process's exit code.

    Raises:
        KeyError: If ``cmd`` or ``session`` is missing.
        FileNotFoundError: If the interpreter executable cannot be found.
    """
    train_arg_list = []
    if 'model' in arg:
        train_arg_list += ['--model', arg['model']]
    session = arg['session']
    cmd = arg['cmd']
    train_arg_list += ['--sess_cfg', session]
    if 'results' in arg:
        train_arg_list += ['--results', arg['results']]
    if 'plot' in arg:
        train_arg_list += ['--plot', str(arg['plot'])]
    # Only resume from a checkpoint that actually exists on disk.
    if 'pretrained' in arg:
        value = arg['pretrained']
        if os.path.isfile(value + '.pth'):
            train_arg_list += ['--pretrained', value]

    python = arg.get('python', 'python.exe')
    terminal_cmd_list = [str(part) for part in
                         [python, '-W', 'ignore', cmd, *train_arg_list]]
    # For display only; execution below uses the list, not this string.
    terminal_cmd = ' '.join(terminal_cmd_list)
    print("\n\n======================================================")
    print("======================================================\n\n")
    # Flush so the banner appears before the child's own output.
    print("[INF] $: %s\n" % (terminal_cmd), flush=True)
    return subprocess.run(terminal_cmd_list, check=False).returncode
def statup_session(**arg):
    """Run the training script once per (model, sequence) pair.

    For every model and every sequence, derives the session name
    ``<type_>_<NN>`` from the number at the end of the sequence identifier
    and calls ``run_script`` with ``cmd=CMD`` and a ``pretrained`` path of
    ``<root>/<model>_<session>``. ``run_script`` only uses that path when
    ``<path>.pth`` exists.

    Keyword Args:
        model: Iterable of model names; a single string is treated as one model.
        root: Checkpoint directory.
        sequences: Iterable of sequence identifiers ending in a number,
            e.g. ``'seq3'`` or ``12``.
        type_: Session-name prefix.
        plot: Value forwarded to ``run_script`` as ``plot`` (default ``0``).

    Raises:
        KeyError: If a required keyword argument is missing.
        ValueError: If a sequence identifier does not end with a number.
    """
    models = arg['model']
    # Iterating a bare string would launch one run per character.
    if isinstance(models, str):
        models = [models]
    root = arg['root']
    sequences = arg['sequences']
    type_ = arg['type_']
    plot = arg.get('plot', 0)
    for model in models:
        for ex in sequences:
            # Use the whole trailing number (not just the last character) so
            # e.g. 'seq12' maps to session '<type_>_12', not '<type_>_02'.
            found = re.search(r'(\d+)$', str(ex))
            if found is None:
                raise ValueError(
                    f"sequence {ex!r} does not end with a number")
            session = '%s_%02d' % (type_, int(found.group(1)))
            pretrained = os.path.join(root, model + '_' + session)
            run_script(cmd=CMD,
                       model=model,
                       session=session,
                       pretrained=pretrained,
                       plot=plot)
if __name__ == '__main__':
    # NOTE(review): TYPE_ and root are not used by the single run below;
    # presumably kept for statup_session(type_=..., root=...) — confirm.
    TYPE_ = 'cross_val'
    root = "checkpoints/"
    session = 'cosine_small_session'
    model = '2bb_1a_norm'
    # Copy the selected model config to model_cfg/model.yaml, then launch
    # training with model='model' (presumably the training script resolves
    # this to that file — verify). Context managers close both files.
    with open('model_cfg/' + model + '.yaml', encoding='utf-8') as src:
        network = yaml.load(src, Loader=yaml.FullLoader)
    with open('model_cfg/model.yaml', 'w', encoding='utf-8') as file:
        yaml.dump(network, file)
    run_script(cmd=CMD, model='model', session=session, plot=1)