Skip to content

Commit f3e2392

Browse files
committed
simple-nmt test
1 parent 145ae37 commit f3e2392

27 files changed

+7822
-0
lines changed

src/simple-nmt/README.md

Lines changed: 342 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import sys
2+
import os.path
3+
4+
import torch
5+
6+
from dual_train import define_argparser
7+
from dual_train import main
8+
9+
from continue_train import overwrite_config
10+
from continue_train import continue_main
11+
12+
13+
if __name__ == '__main__':
14+
config = define_argparser(is_continue=True)
15+
continue_main(config, main)

src/simple-nmt/continue_train.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import sys
2+
import os.path
3+
4+
import torch
5+
6+
from train import define_argparser
7+
from train import main
8+
9+
10+
def overwrite_config(config, prev_config):
11+
# This method provides a compatibility for new or missing arguments.
12+
for prev_key in vars(prev_config).keys():
13+
if not prev_key in vars(config).keys():
14+
# No such argument in current config. Ignore that value.
15+
print('WARNING!!! Argument "--%s" is not found in current argument parser.\tIgnore saved value:' % prev_key,
16+
vars(prev_config)[prev_key])
17+
18+
for key in vars(config).keys():
19+
if not key in vars(prev_config).keys():
20+
# No such argument in saved file. Use current value.
21+
print('WARNING!!! Argument "--%s" is not found in saved model.\tUse current value:' % key,
22+
vars(config)[key])
23+
elif vars(config)[key] != vars(prev_config)[key]:
24+
if '--%s' % key in sys.argv:
25+
# User changed argument value at this execution.
26+
print('WARNING!!! You changed value for argument "--%s".\tUse current value:' % key,
27+
vars(config)[key])
28+
else:
29+
# User didn't changed at this execution, but current config and saved config is different.
30+
# This may caused by user's intension at last execution.
31+
# Load old value, and replace current value.
32+
vars(config)[key] = vars(prev_config)[key]
33+
34+
return config
35+
36+
37+
def continue_main(config, main):
38+
# If the model exists, load model and configuration to continue the training.
39+
if os.path.isfile(config.load_fn):
40+
saved_data = torch.load(config.load_fn, map_location='cpu')
41+
42+
prev_config = saved_data['config']
43+
config = overwrite_config(config, prev_config)
44+
45+
model_weight = saved_data['model']
46+
opt_weight = saved_data['opt']
47+
48+
main(config, model_weight=model_weight, opt_weight=opt_weight)
49+
else:
50+
print('Cannot find file %s' % config.load_fn)
51+
52+
53+
if __name__ == '__main__':
54+
config = define_argparser(is_continue=True)
55+
continue_main(config, main)

0 commit comments

Comments
 (0)