minor refactor to get rid of cache_cfg_urls (facebookresearch#155)
Summary:
The reason for this refactor is that cache_cfg_urls() alters the global
config and does not play nicely with multi-node training (coming soon).
File caching now happens inside trainer.py, where and when it is needed.
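For context, the pattern the diff below adopts: only the master process downloads, and every process waits at a barrier before using the file. A minimal sketch of that pattern (the fetch_weights wrapper is illustrative, not part of pycls):

import torch
from pycls.core.io import cache_url

def fetch_weights(weights_file, cache_dir, is_master, num_gpus):
    # Only the master downloads; other ranks just resolve the cache path.
    path = cache_url(weights_file, cache_dir, download=is_master)
    if num_gpus > 1:
        # All ranks wait here until the master's download completes.
        torch.distributed.barrier()
    return path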

Pull Request resolved: facebookresearch#155

Reviewed By: mannatsingh

Differential Revision: D29627547

Pulled By: pdollar

fbshipit-source-id: f397ec9ebe38733d204b95cca7c69618aeadba53

Co-authored-by: Mannat Singh <13458796+mannatsingh@users.noreply.github.com>
Co-authored-by: Piotr Dollar <699682+pdollar@users.noreply.github.com>
3 people authored and facebook-github-bot committed Jul 9, 2021
1 parent 9721320 commit 750a36e
Showing 4 changed files with 23 additions and 19 deletions.
12 changes: 2 additions & 10 deletions pycls/core/config.py
@@ -9,7 +9,7 @@

import os

-from pycls.core.io import cache_url, pathmgr
+from pycls.core.io import pathmgr
from yacs.config import CfgNode


@@ -427,7 +427,7 @@
_C.register_deprecated_key("TRAIN.CHECKPOINT_PERIOD")


-def assert_and_infer_cfg(cache_urls=True):
+def assert_cfg():
"""Checks config values invariants."""
err_str = "The first lr step must start at 0"
assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
@@ -440,14 +440,6 @@ def assert_and_infer_cfg(cache_urls=True):
assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
err_str = "Log destination '{}' not supported"
assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
-    if cache_urls:
-        cache_cfg_urls()
-
-
-def cache_cfg_urls():
-    """Download URLs in config, cache them, and rewrite cfg to use cached file."""
-    _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
-    _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)


def dump_cfg():
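With cache_cfg_urls() removed, validating the config no longer has download side effects. A minimal sketch of the updated call site (the config path is illustrative; compare tools/run_net.py below):

import pycls.core.config as config
from pycls.core.config import cfg

config.load_cfg("configs/example.yaml")  # illustrative path
config.assert_cfg()  # checks invariants only; no URL caching or rewriting
cfg.freeze()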
7 changes: 4 additions & 3 deletions pycls/core/io.py
@@ -24,7 +24,7 @@
_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"


-def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL):
+def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL, download=True):
"""Download the file specified by the URL to the cache_dir and return the path to
the cached file. If the argument is not a URL, simply return it as is.
"""
@@ -39,8 +39,9 @@ def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL):
cache_file_dir = os.path.dirname(cache_file_path)
if not pathmgr.exists(cache_file_dir):
pathmgr.mkdirs(cache_file_dir)
logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
download_url(url, cache_file_path)
if download:
logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
download_url(url, cache_file_path)
return cache_file_path


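The new download flag lets a non-master process resolve where a URL would be cached without fetching it. A minimal sketch (URL and cache dir are illustrative; per the docstring, non-URL paths pass through unchanged):

from pycls.core.io import cache_url

# Resolve the cache path only; no network access.
url = "https://dl.fbaipublicfiles.com/pycls/example/weights.pyth"
path = cache_url(url, "/tmp/pycls-cache", download=False)

# A plain filesystem path is returned as-is.
local = cache_url("/data/weights.pyth", "/tmp/pycls-cache")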
21 changes: 16 additions & 5 deletions pycls/core/trainer.py
@@ -25,7 +25,7 @@
import torch
import torch.cuda.amp as amp
from pycls.core.config import cfg
-from pycls.core.io import pathmgr
+from pycls.core.io import cache_url, pathmgr


logger = logging.get_logger(__name__)
@@ -76,6 +76,15 @@ def setup_model():
return model


+def get_weights_file(weights_file):
+    """Download weights file if stored as a URL."""
+    download = dist.is_master_proc()
+    weights_file = cache_url(weights_file, cfg.DOWNLOAD_CACHE, download=download)
+    if cfg.NUM_GPUS > 1:
+        torch.distributed.barrier()
+    return weights_file


def train_epoch(loader, model, ema, loss_fun, optimizer, scaler, meter, cur_epoch):
"""Performs one epoch of training."""
# Shuffle the data
@@ -166,8 +175,9 @@ def train_model():
logger.info("Loaded checkpoint from: {}".format(file))
start_epoch = epoch + 1
elif cfg.TRAIN.WEIGHTS:
-        cp.load_checkpoint(cfg.TRAIN.WEIGHTS, model, ema)
-        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
+        train_weights = get_weights_file(cfg.TRAIN.WEIGHTS)
+        cp.load_checkpoint(train_weights, model, ema)
+        logger.info("Loaded initial weights from: {}".format(train_weights))
# Create data loaders and meters
train_loader = data_loader.construct_train_loader()
test_loader = data_loader.construct_test_loader()
@@ -206,8 +216,9 @@ def test_model():
# Construct the model
model = setup_model()
# Load model weights
-    cp.load_checkpoint(cfg.TEST.WEIGHTS, model)
-    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
+    test_weights = get_weights_file(cfg.TEST.WEIGHTS)
+    cp.load_checkpoint(test_weights, model)
+    logger.info("Loaded model weights from: {}".format(test_weights))
# Create data loaders and meters
test_loader = data_loader.construct_test_loader()
test_meter = meters.TestMeter(len(test_loader))
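The barrier in get_weights_file is what makes this safe under multi-node training: without it, a non-master rank could try to load a file the master is still downloading. A sketch of the race this guards against (illustrative):

# On a non-master rank, with no barrier:
path = cache_url(cfg.TRAIN.WEIGHTS, cfg.DOWNLOAD_CACHE, download=False)
cp.load_checkpoint(path, model)  # may read a missing or partially written file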
2 changes: 1 addition & 1 deletion tools/run_net.py
@@ -40,7 +40,7 @@ def main():
mode = args.mode
config.load_cfg(args.cfg)
cfg.merge_from_list(args.opts)
-    config.assert_and_infer_cfg()
+    config.assert_cfg()
cfg.freeze()
if mode == "info":
print(builders.get_model()())
