Skip to content

Commit

Permalink
small code style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
SunQpark committed May 2, 2019
1 parent 0b5e207 commit 940da3b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion base/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class BaseModel(nn.Module):
Base class for all models
"""
@abstractmethod
def forward(self, *input):
def forward(self, *inputs):
"""
Forward pass logic
Expand Down
10 changes: 5 additions & 5 deletions parse_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def __init__(self, args, options='', timestamp=True):

# load config file and apply custom cli options
config = read_json(self.cfg_fname)
self.__config = _update_config(config, options, args)
self._config = _update_config(config, options, args)

# set save_dir where trained model and log will be saved.
save_dir = Path(self.config['trainer']['save_dir'])
timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''

exper_name = self.config['name']
self.__save_dir = save_dir / 'models' / exper_name / timestamp
self.__log_dir = save_dir / 'log' / exper_name / timestamp
self._save_dir = save_dir / 'models' / exper_name / timestamp
self._log_dir = save_dir / 'log' / exper_name / timestamp

self.save_dir.mkdir(parents=True, exist_ok=True)
self.log_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -80,11 +80,11 @@ def config(self):

@property
def save_dir(self):
return self.__save_dir
return self._save_dir

@property
def log_dir(self):
return self.__log_dir
return self._log_dir

# helper functions used to update config dict with custom cli options
def _update_config(config, options, args):
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def main(config):
if __name__ == '__main__':
args = argparse.ArgumentParser(description='PyTorch Template')
args.add_argument('-c', '--config', default=None, type=str,
help='config file path (default: None)')
help='config file path (default: None)')
args.add_argument('-r', '--resume', default=None, type=str,
help='path to latest checkpoint (default: None)')
help='path to latest checkpoint (default: None)')
args.add_argument('-d', '--device', default=None, type=str,
help='indices of GPUs to enable (default: all)')
help='indices of GPUs to enable (default: all)')

# custom cli options to modify configuration from default values given in json file.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
Expand Down

0 comments on commit 940da3b

Please sign in to comment.