From 750a36e203abc732e17f25dd335c7e121c9725cc Mon Sep 17 00:00:00 2001
From: Piotr Dollar
Date: Thu, 8 Jul 2021 17:41:44 -0700
Subject: [PATCH] minor refactor to get rid of cache_cfg_urls (#155)

Summary:
The reason for this refactor is that cache_cfg_urls() alters the global
config and does not play nicely with multi-node training (coming soon).
So now file caching occurs inside trainer.py, where and when it is needed.

Pull Request resolved: https://github.com/facebookresearch/pycls/pull/155

Reviewed By: mannatsingh

Differential Revision: D29627547

Pulled By: pdollar

fbshipit-source-id: f397ec9ebe38733d204b95cca7c69618aeadba53

Co-authored-by: Mannat Singh <13458796+mannatsingh@users.noreply.github.com>
Co-authored-by: Piotr Dollar <699682+pdollar@users.noreply.github.com>
---
 pycls/core/config.py  | 12 ++----------
 pycls/core/io.py      |  7 ++++---
 pycls/core/trainer.py | 21 ++++++++++++++++-----
 tools/run_net.py      |  2 +-
 4 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/pycls/core/config.py b/pycls/core/config.py
index 4528921..dac6ae3 100644
--- a/pycls/core/config.py
+++ b/pycls/core/config.py
@@ -9,7 +9,7 @@
 
 import os
 
-from pycls.core.io import cache_url, pathmgr
+from pycls.core.io import pathmgr
 from yacs.config import CfgNode
 
 
@@ -427,7 +427,7 @@
 _C.register_deprecated_key("TRAIN.CHECKPOINT_PERIOD")
 
 
-def assert_and_infer_cfg(cache_urls=True):
+def assert_cfg():
     """Checks config values invariants."""
     err_str = "The first lr step must start at 0"
     assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
@@ -440,14 +440,6 @@ def assert_and_infer_cfg(cache_urls=True):
     assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
     err_str = "Log destination '{}' not supported"
     assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
-    if cache_urls:
-        cache_cfg_urls()
-
-
-def cache_cfg_urls():
-    """Download URLs in config, cache them, and rewrite cfg to use cached file."""
-    _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
-    _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
 
 
 def dump_cfg():
diff --git a/pycls/core/io.py b/pycls/core/io.py
index ccd3b8d..86ef516 100644
--- a/pycls/core/io.py
+++ b/pycls/core/io.py
@@ -24,7 +24,7 @@
 _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
 
 
-def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL):
+def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL, download=True):
     """Download the file specified by the URL to the cache_dir and return the path to
     the cached file. If the argument is not a URL, simply return it as is.
""" @@ -39,8 +39,9 @@ def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL): cache_file_dir = os.path.dirname(cache_file_path) if not pathmgr.exists(cache_file_dir): pathmgr.mkdirs(cache_file_dir) - logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) - download_url(url, cache_file_path) + if download: + logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) + download_url(url, cache_file_path) return cache_file_path diff --git a/pycls/core/trainer.py b/pycls/core/trainer.py index a53cd2c..5449bcd 100644 --- a/pycls/core/trainer.py +++ b/pycls/core/trainer.py @@ -25,7 +25,7 @@ import torch import torch.cuda.amp as amp from pycls.core.config import cfg -from pycls.core.io import pathmgr +from pycls.core.io import cache_url, pathmgr logger = logging.get_logger(__name__) @@ -76,6 +76,15 @@ def setup_model(): return model +def get_weights_file(weights_file): + """Download weights file if stored as a URL.""" + download = dist.is_master_proc() + weights_file = cache_url(weights_file, cfg.DOWNLOAD_CACHE, download=download) + if cfg.NUM_GPUS > 1: + torch.distributed.barrier() + return weights_file + + def train_epoch(loader, model, ema, loss_fun, optimizer, scaler, meter, cur_epoch): """Performs one epoch of training.""" # Shuffle the data @@ -166,8 +175,9 @@ def train_model(): logger.info("Loaded checkpoint from: {}".format(file)) start_epoch = epoch + 1 elif cfg.TRAIN.WEIGHTS: - cp.load_checkpoint(cfg.TRAIN.WEIGHTS, model, ema) - logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS)) + train_weights = get_weights_file(cfg.TRAIN.WEIGHTS) + cp.load_checkpoint(train_weights, model, ema) + logger.info("Loaded initial weights from: {}".format(train_weights)) # Create data loaders and meters train_loader = data_loader.construct_train_loader() test_loader = data_loader.construct_test_loader() @@ -206,8 +216,9 @@ def test_model(): # Construct the model model = setup_model() # Load model weights - cp.load_checkpoint(cfg.TEST.WEIGHTS, model) - logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS)) + test_weights = get_weights_file(cfg.TEST.WEIGHTS) + cp.load_checkpoint(test_weights, model) + logger.info("Loaded model weights from: {}".format(test_weights)) # Create data loaders and meters test_loader = data_loader.construct_test_loader() test_meter = meters.TestMeter(len(test_loader)) diff --git a/tools/run_net.py b/tools/run_net.py index 0abe6a0..233768f 100755 --- a/tools/run_net.py +++ b/tools/run_net.py @@ -40,7 +40,7 @@ def main(): mode = args.mode config.load_cfg(args.cfg) cfg.merge_from_list(args.opts) - config.assert_and_infer_cfg() + config.assert_cfg() cfg.freeze() if mode == "info": print(builders.get_model()())