Skip to content

Commit 0bc32c4

Browse files
committed
Merge branch 'master' of github.com:victoresque/pytorch-template into update
2 parents 36ffcb3 + a432181 commit 0bc32c4

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

base/base_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def train(self):
6161
"""
6262
Full training logic
6363
"""
64+
not_improved_count = 0
6465
for epoch in range(self.start_epoch, self.epochs + 1):
6566
result = self._train_epoch(epoch)
6667

@@ -90,7 +91,6 @@ def train(self):
9091
"Model performance monitoring is disabled.".format(self.mnt_metric))
9192
self.mnt_mode = 'off'
9293
improved = False
93-
not_improved_count = 0
9494

9595
if improved:
9696
self.mnt_best = log[self.mnt_metric]

parse_config.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,20 @@ def __init__(self, args, options='', timestamp=True):
1717

1818
if args.device:
1919
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
20-
if args.resume:
21-
self.resume = Path(args.resume)
22-
self.cfg_fname = self.resume.parent / 'config.json'
23-
else:
20+
if args.resume is None:
2421
msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
2522
assert args.config is not None, msg_no_cfg
26-
self.resume = None
2723
self.cfg_fname = Path(args.config)
24+
config = read_json(self.cfg_fname)
25+
self.resume = None
26+
else:
27+
self.resume = Path(args.resume)
28+
resume_cfg_fname = self.resume.parent / 'config.json'
29+
config = read_json(resume_cfg_fname)
30+
if args.config is not None:
31+
config.update(read_json(Path(args.config)))
2832

2933
# load config file and apply custom cli options
30-
config = read_json(self.cfg_fname)
3134
self._config = _update_config(config, options, args)
3235

3336
# set save_dir where trained model and log will be saved.

0 commit comments

Comments
 (0)