|
| 1 | +import sys |
| 2 | +import os.path |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from train import define_argparser |
| 7 | +from train import main |
| 8 | + |
| 9 | + |
| 10 | +def overwrite_config(config, prev_config): |
| 11 | + # This method provides a compatibility for new or missing arguments. |
| 12 | + for prev_key in vars(prev_config).keys(): |
| 13 | + if not prev_key in vars(config).keys(): |
| 14 | + # No such argument in current config. Ignore that value. |
| 15 | + print('WARNING!!! Argument "--%s" is not found in current argument parser.\tIgnore saved value:' % prev_key, |
| 16 | + vars(prev_config)[prev_key]) |
| 17 | + |
| 18 | + for key in vars(config).keys(): |
| 19 | + if not key in vars(prev_config).keys(): |
| 20 | + # No such argument in saved file. Use current value. |
| 21 | + print('WARNING!!! Argument "--%s" is not found in saved model.\tUse current value:' % key, |
| 22 | + vars(config)[key]) |
| 23 | + elif vars(config)[key] != vars(prev_config)[key]: |
| 24 | + if '--%s' % key in sys.argv: |
| 25 | + # User changed argument value at this execution. |
| 26 | + print('WARNING!!! You changed value for argument "--%s".\tUse current value:' % key, |
| 27 | + vars(config)[key]) |
| 28 | + else: |
| 29 | + # User didn't changed at this execution, but current config and saved config is different. |
| 30 | + # This may caused by user's intension at last execution. |
| 31 | + # Load old value, and replace current value. |
| 32 | + vars(config)[key] = vars(prev_config)[key] |
| 33 | + |
| 34 | + return config |
| 35 | + |
| 36 | + |
| 37 | +def continue_main(config, main): |
| 38 | + # If the model exists, load model and configuration to continue the training. |
| 39 | + if os.path.isfile(config.load_fn): |
| 40 | + saved_data = torch.load(config.load_fn, map_location='cpu') |
| 41 | + |
| 42 | + prev_config = saved_data['config'] |
| 43 | + config = overwrite_config(config, prev_config) |
| 44 | + |
| 45 | + model_weight = saved_data['model'] |
| 46 | + opt_weight = saved_data['opt'] |
| 47 | + |
| 48 | + main(config, model_weight=model_weight, opt_weight=opt_weight) |
| 49 | + else: |
| 50 | + print('Cannot find file %s' % config.load_fn) |
| 51 | + |
| 52 | + |
| 53 | +if __name__ == '__main__': |
| 54 | + config = define_argparser(is_continue=True) |
| 55 | + continue_main(config, main) |
0 commit comments