From 4fab913184b4acd2eeb039a5d9bc5a3b3194a759 Mon Sep 17 00:00:00 2001
From: Piotr Dollar
Date: Wed, 17 Mar 2021 17:09:03 -0700
Subject: [PATCH] minor cleanup after io migration to PathManager (#130)

Summary:
best to view this commit using a diff from ca89a79 from 11/20/20
this diff looks big on its own, but it's MINIMAL versus ca89a79

this diff undoes some of the orthogonal changes in the previous few diffs
-no longer using global path manager (which was deprecated)
-cfg.merge_from_file replaced by load_cfg (old load_cfg unused)
-correct / minimal requirement in requirements.txt
-setup_env() moved back into trainer.py (should not be used elsewhere)
-removed orthogonal changes to cifar10.py introduced in prev commits

Pull Request resolved: https://github.com/facebookresearch/pycls/pull/130

Reviewed By: theschnitz

Differential Revision: D27128859

Pulled By: pdollar

fbshipit-source-id: 1029adca3b1bf8c33fae1fc2f82bb658f34cb56b
---
 dev/model_zoo_tables.py   |  4 ++--
 dev/test_models.py        | 10 ++++------
 pycls/core/checkpoint.py  | 31 +++++++++++++------------------
 pycls/core/config.py      | 23 ++++++++---------------
 pycls/core/env.py         | 39 ++++++---------------------------------
 pycls/core/io.py          | 13 ++++++++-----
 pycls/core/logging.py     | 10 +++++-----
 pycls/core/trainer.py     | 37 ++++++++++++++++++++++++++++++++-----
 pycls/datasets/cifar10.py | 12 ++++++------
 pycls/models/model_zoo.py |  4 ++--
 requirements.txt          |  2 +-
 11 files changed, 87 insertions(+), 98 deletions(-)

diff --git a/dev/model_zoo_tables.py b/dev/model_zoo_tables.py
index f7f07f5..a6ff88d 100755
--- a/dev/model_zoo_tables.py
+++ b/dev/model_zoo_tables.py
@@ -13,7 +13,7 @@
 import pycls.core.builders as builders
 import pycls.core.net as net
 import pycls.models.model_zoo as model_zoo
-from pycls.core.config import cfg, reset_cfg
+from pycls.core.config import cfg, load_cfg, reset_cfg
 
 
 # Location of pycls directory
@@ -47,7 +47,7 @@ def get_model_data(name, timings, errors):
     """Get model data for a single model."""
     # Load model config
     reset_cfg()
-    cfg.merge_from_file(model_zoo.get_config_file(name))
+    load_cfg(model_zoo.get_config_file(name))
     config_url, _, model_id, _, weight_url_full = model_zoo.get_model_info(name)
     # Get model complexity
     cx = net.complexity(builders.get_model())
diff --git a/dev/test_models.py b/dev/test_models.py
index 53d4cd2..eb5a873 100755
--- a/dev/test_models.py
+++ b/dev/test_models.py
@@ -16,13 +16,12 @@
 
 import pycls.core.builders as builders
 import pycls.core.distributed as dist
-import pycls.core.env as env
 import pycls.core.logging as logging
 import pycls.core.net as net
 import pycls.core.trainer as trainer
 import pycls.models.model_zoo as model_zoo
 from parameterized import parameterized
-from pycls.core.config import cfg, merge_from_file, reset_cfg
+from pycls.core.config import cfg, load_cfg, reset_cfg
 
 
 # Location of pycls directory
@@ -38,14 +37,14 @@ def test_complexity(key):
     """Measure the complexity of a single model."""
     reset_cfg()
     cfg_file = os.path.join(_PYCLS_DIR, key)
-    merge_from_file(cfg_file)
+    load_cfg(cfg_file)
     return net.complexity(builders.get_model())
 
 
 def test_timing(key):
     """Measure the timing of a single model."""
     reset_cfg()
-    merge_from_file(model_zoo.get_config_file(key))
+    load_cfg(model_zoo.get_config_file(key))
     cfg.PREC_TIME.WARMUP_ITER, cfg.PREC_TIME.NUM_ITER = 5, 50
     cfg.OUT_DIR, cfg.LOG_DEST = tempfile.mkdtemp(), "file"
     dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.time_model)
@@ -58,7 +57,7 @@ def test_timing(key):
 def test_error(key):
     """Measure the error of a single model."""
     reset_cfg()
-    merge_from_file(model_zoo.get_config_file(key))
+    load_cfg(model_zoo.get_config_file(key))
     cfg.TEST.WEIGHTS = model_zoo.get_weights_file(key)
     cfg.OUT_DIR, cfg.LOG_DEST = tempfile.mkdtemp(), "file"
     dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model)
@@ -138,7 +137,6 @@ class TestError(unittest.TestCase):
     @parameterized.expand(parse_tests(load_test_data("error")), skip_on_empty=True)
     @unittest.skipIf(not _RUN_ERROR_TESTS, "Skipping error tests")
     def test(self, key, out_expected):
-        env.setup_env()
         print("\nTesting error of: {}".format(key))
         out = test_error(key)
         print("expected = {}".format(out_expected))
diff --git a/pycls/core/checkpoint.py b/pycls/core/checkpoint.py
index aa2296e..371b172 100644
--- a/pycls/core/checkpoint.py
+++ b/pycls/core/checkpoint.py
@@ -8,12 +8,11 @@
 """Functions that handle saving and loading of checkpoints."""
 
 import os
-from shutil import copyfileobj
 
 import pycls.core.distributed as dist
 import torch
-from iopath.common.file_io import g_pathmgr
 from pycls.core.config import cfg
+from pycls.core.io import pathmgr
 from pycls.core.net import unwrap_model
 
 
@@ -43,7 +42,7 @@ def get_checkpoint_best():
 def get_last_checkpoint():
     """Retrieves the most recent checkpoint (highest epoch number)."""
     checkpoint_dir = get_checkpoint_dir()
-    checkpoints = [f for f in g_pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
+    checkpoints = [f for f in pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
     last_checkpoint_name = sorted(checkpoints)[-1]
     return os.path.join(checkpoint_dir, last_checkpoint_name)
 
@@ -51,9 +50,9 @@ def get_last_checkpoint():
 def has_checkpoint():
     """Determines if there are checkpoints available."""
     checkpoint_dir = get_checkpoint_dir()
-    if not g_pathmgr.exists(checkpoint_dir):
+    if not pathmgr.exists(checkpoint_dir):
         return False
-    return any(_NAME_PREFIX in f for f in g_pathmgr.ls(checkpoint_dir))
+    return any(_NAME_PREFIX in f for f in pathmgr.ls(checkpoint_dir))
 
 
 def save_checkpoint(model, optimizer, epoch, best):
@@ -62,7 +61,7 @@
     """Saves a checkpoint."""
     # Save checkpoints only from the master process
     if not dist.is_master_proc():
         return
     # Ensure that the checkpoint dir exists
-    g_pathmgr.mkdirs(get_checkpoint_dir())
+    pathmgr.mkdirs(get_checkpoint_dir())
     # Record the state
     checkpoint = {
         "epoch": epoch,
@@ -72,21 +71,19 @@
     }
     # Write the checkpoint
     checkpoint_file = get_checkpoint(epoch + 1)
-    with g_pathmgr.open(checkpoint_file, "wb") as f:
+    with pathmgr.open(checkpoint_file, "wb") as f:
         torch.save(checkpoint, f)
     # If best copy checkpoint to the best checkpoint
     if best:
-        with g_pathmgr.open(checkpoint_file, "rb") as src:
-            with g_pathmgr.open(get_checkpoint_best(), "wb") as dst:
-                copyfileobj(src, dst)
+        pathmgr.copy(checkpoint_file, get_checkpoint_best())
     return checkpoint_file
 
 
 def load_checkpoint(checkpoint_file, model, optimizer=None):
     """Loads the checkpoint from the given file."""
     err_str = "Checkpoint '{}' not found"
-    assert g_pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
-    with g_pathmgr.open(checkpoint_file, "rb") as f:
+    assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
+    with pathmgr.open(checkpoint_file, "rb") as f:
         checkpoint = torch.load(f, map_location="cpu")
     unwrap_model(model).load_state_dict(checkpoint["model_state"])
     optimizer.load_state_dict(checkpoint["optimizer_state"]) if optimizer else ()
@@ -97,12 +94,10 @@ def delete_checkpoints(checkpoint_dir=None, keep="all"):
     """Deletes unneeded checkpoints, keep can be "all", "last", or "none"."""
     assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep)
     checkpoint_dir = checkpoint_dir if checkpoint_dir else get_checkpoint_dir()
-    if keep == "all" or not g_pathmgr.exists(checkpoint_dir):
+    if keep == "all" or not pathmgr.exists(checkpoint_dir):
         return 0
-    checkpoints = [f for f in g_pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
+    checkpoints = [f for f in pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
     checkpoints = sorted(checkpoints)[:-1] if keep == "last" else checkpoints
-    [
-        g_pathmgr.rm(os.path.join(checkpoint_dir, checkpoint))
-        for checkpoint in checkpoints
-    ]
+    for checkpoint in checkpoints:
+        pathmgr.rm(os.path.join(checkpoint_dir, checkpoint))
     return len(checkpoints)
diff --git a/pycls/core/config.py b/pycls/core/config.py
index 30e8c48..370029a 100644
--- a/pycls/core/config.py
+++ b/pycls/core/config.py
@@ -11,8 +11,7 @@
 import os
 import sys
 
-from iopath.common.file_io import g_pathmgr
-from pycls.core.io import cache_url
+from pycls.core.io import cache_url, pathmgr
 from yacs.config import CfgNode as CfgNode
 
 
@@ -378,28 +377,22 @@ def cache_cfg_urls():
     _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
 
 
-def merge_from_file(cfg_file):
-    with g_pathmgr.open(cfg_file, "r") as f:
-        cfg = _C.load_cfg(f)
-    _C.merge_from_other_cfg(cfg)
-
-
 def dump_cfg():
     """Dumps the config to the output directory."""
     cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
-    with g_pathmgr.open(cfg_file, "w") as f:
+    with pathmgr.open(cfg_file, "w") as f:
         _C.dump(stream=f)
 
 
-def load_cfg(out_dir, cfg_dest="config.yaml"):
-    """Loads config from specified output directory."""
-    cfg_file = os.path.join(out_dir, cfg_dest)
-    merge_from_file(cfg_file)
+def load_cfg(cfg_file):
+    """Loads config from specified file."""
+    with pathmgr.open(cfg_file, "r") as f:
+        _C.merge_from_other_cfg(_C.load_cfg(f))
 
 
 def reset_cfg():
     """Reset config to initial state."""
-    cfg.merge_from_other_cfg(_CFG_DEFAULT)
+    _C.merge_from_other_cfg(_CFG_DEFAULT)
 
 
 def load_cfg_fom_args(description="Config file options."):
@@ -413,5 +406,5 @@
         parser.print_help()
         sys.exit(1)
     args = parser.parse_args()
-    merge_from_file(args.cfg_file)
+    load_cfg(args.cfg_file)
     _C.merge_from_list(args.opts)
diff --git a/pycls/core/env.py b/pycls/core/env.py
index dde2581..37a147a 100644
--- a/pycls/core/env.py
+++ b/pycls/core/env.py
@@ -1,35 +1,8 @@
-import random
+#!/usr/bin/env python3
 
-import numpy as np
-import pycls.core.config as config
-import pycls.core.distributed as dist
-import pycls.core.logging as logging
-import torch
-from iopath.common.file_io import g_pathmgr
-from pycls.core.config import cfg
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
 
-
-logger = logging.get_logger(__name__)
-
-
-def setup_env():
-    """Sets up environment for training or testing."""
-    if dist.is_master_proc():
-        # Ensure that the output dir exists
-        g_pathmgr.mkdirs(cfg.OUT_DIR)
-        # Save the config
-        config.dump_cfg()
-    # Setup logging
-    logging.setup_logging()
-    # Log torch, cuda, and cudnn versions
-    version = [torch.__version__, torch.version.cuda, torch.backends.cudnn.version()]
-    logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
-    # Log the config as both human readable and as a json
-    logger.info("Config:\n{}".format(cfg)) if cfg.VERBOSE else ()
-    logger.info(logging.dump_log_data(cfg, "cfg", None))
-    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
-    np.random.seed(cfg.RNG_SEED)
-    torch.manual_seed(cfg.RNG_SEED)
-    random.seed(cfg.RNG_SEED)
-    # Configure the CUDNN backend
-    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
+"""This file is obsolete and will be removed in a future commit."""
diff --git a/pycls/core/io.py b/pycls/core/io.py
index 0f19a37..ccd3b8d 100644
--- a/pycls/core/io.py
+++ b/pycls/core/io.py
@@ -13,9 +13,12 @@
 import sys
 from urllib import request as urlrequest
 
-from iopath.common.file_io import g_pathmgr
+from iopath.common.file_io import PathManagerFactory
 
 
+# instantiate global path manager for pycls
+pathmgr = PathManagerFactory.get()
+
 logger = logging.getLogger(__name__)
 
 _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
@@ -31,11 +34,11 @@ def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL):
     url = url_or_file
     assert url.startswith(base_url), "url must start with: {}".format(base_url)
     cache_file_path = url.replace(base_url, cache_dir)
-    if g_pathmgr.exists(cache_file_path):
+    if pathmgr.exists(cache_file_path):
         return cache_file_path
     cache_file_dir = os.path.dirname(cache_file_path)
-    if not g_pathmgr.exists(cache_file_dir):
-        g_pathmgr.mkdirs(cache_file_dir)
+    if not pathmgr.exists(cache_file_dir):
+        pathmgr.mkdirs(cache_file_dir)
     logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
     download_url(url, cache_file_path)
     return cache_file_path
@@ -66,7 +69,7 @@ def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_ba
     total_size = response.info().get("Content-Length").strip()
     total_size = int(total_size)
     bytes_so_far = 0
-    with g_pathmgr.open(dst_file_path, "wb") as f:
+    with pathmgr.open(dst_file_path, "wb") as f:
         while 1:
             chunk = response.read(chunk_size)
             bytes_so_far += len(chunk)
diff --git a/pycls/core/logging.py b/pycls/core/logging.py
index 94d543a..993d1b9 100644
--- a/pycls/core/logging.py
+++ b/pycls/core/logging.py
@@ -15,8 +15,8 @@
 
 import pycls.core.distributed as dist
 import simplejson
-from iopath.common.file_io import g_pathmgr
 from pycls.core.config import cfg
+from pycls.core.io import pathmgr
 
 
 # Show filename and line number in logs
@@ -86,9 +86,9 @@ def float_to_decimal(data, prec=4):
 
 def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
     """Get all log files in directory containing subdirs of trained models."""
-    names = [n for n in sorted(g_pathmgr.ls(log_dir)) if name_filter in n]
+    names = [n for n in sorted(pathmgr.ls(log_dir)) if name_filter in n]
     files = [os.path.join(log_dir, n, log_file) for n in names]
-    f_n_ps = [(f, n) for (f, n) in zip(files, names) if g_pathmgr.exists(f)]
+    f_n_ps = [(f, n) for (f, n) in zip(files, names) if pathmgr.exists(f)]
     files, names = zip(*f_n_ps) if f_n_ps else ([], [])
     return files, names
 
@@ -96,8 +96,8 @@ def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
 def load_log_data(log_file, data_types_to_skip=()):
     """Loads log data into a dictionary of the form data[data_type][metric][index]."""
     # Load log_file
-    assert g_pathmgr.exists(log_file), "Log file not found: {}".format(log_file)
-    with g_pathmgr.open(log_file, "r") as f:
+    assert pathmgr.exists(log_file), "Log file not found: {}".format(log_file)
+    with pathmgr.open(log_file, "r") as f:
         lines = f.readlines()
     # Extract and parse lines that start with _TAG and have a type specified
     lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
diff --git a/pycls/core/trainer.py b/pycls/core/trainer.py
index faf23ca..7037df4 100644
--- a/pycls/core/trainer.py
+++ b/pycls/core/trainer.py
@@ -6,12 +6,15 @@
 # LICENSE file in the root directory of this source tree.
 
 """Tools for training and testing a model."""
 
+
+import random
+
 import numpy as np
 import pycls.core.benchmark as benchmark
 import pycls.core.builders as builders
 import pycls.core.checkpoint as cp
+import pycls.core.config as config
 import pycls.core.distributed as dist
-import pycls.core.env as env
 import pycls.core.logging as logging
 import pycls.core.meters as meters
 import pycls.core.net as net
@@ -20,11 +23,35 @@
 import torch
 import torch.cuda.amp as amp
 from pycls.core.config import cfg
+from pycls.core.io import pathmgr
 
 
 logger = logging.get_logger(__name__)
 
 
+def setup_env():
+    """Sets up environment for training or testing."""
+    if dist.is_master_proc():
+        # Ensure that the output dir exists
+        pathmgr.mkdirs(cfg.OUT_DIR)
+        # Save the config
+        config.dump_cfg()
+    # Setup logging
+    logging.setup_logging()
+    # Log torch, cuda, and cudnn versions
+    version = [torch.__version__, torch.version.cuda, torch.backends.cudnn.version()]
+    logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
+    # Log the config as both human readable and as a json
+    logger.info("Config:\n{}".format(cfg)) if cfg.VERBOSE else ()
+    logger.info(logging.dump_log_data(cfg, "cfg", None))
+    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
+    np.random.seed(cfg.RNG_SEED)
+    torch.manual_seed(cfg.RNG_SEED)
+    random.seed(cfg.RNG_SEED)
+    # Configure the CUDNN backend
+    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
+
+
 def setup_model():
     """Sets up a model for training or testing and log the results."""
     # Build the model
@@ -118,7 +145,7 @@ def test_epoch(loader, model, meter, cur_epoch):
 def train_model():
     """Trains the model."""
     # Setup training/testing environment
-    env.setup_env()
+    setup_env()
     # Construct the model, loss_fun, and optimizer
     model = setup_model()
     loss_fun = builders.build_loss_fun().cuda()
@@ -167,7 +194,7 @@ def train_model():
 def test_model():
     """Evaluates a trained model."""
     # Setup training/testing environment
-    env.setup_env()
+    setup_env()
     # Construct the model
     model = setup_model()
     # Load model weights
@@ -183,7 +210,7 @@ def test_model():
 def time_model():
     """Times model."""
     # Setup training/testing environment
-    env.setup_env()
+    setup_env()
     # Construct the model and loss_fun
     model = setup_model()
     loss_fun = builders.build_loss_fun().cuda()
@@ -194,7 +221,7 @@ def time_model():
 def time_model_and_loader():
     """Times model and data loader."""
     # Setup training/testing environment
-    env.setup_env()
+    setup_env()
     # Construct the model and loss_fun
     model = setup_model()
     loss_fun = builders.build_loss_fun().cuda()
diff --git a/pycls/datasets/cifar10.py b/pycls/datasets/cifar10.py
index 1c294fd..6b36d53 100644
--- a/pycls/datasets/cifar10.py
+++ b/pycls/datasets/cifar10.py
@@ -13,8 +13,8 @@
 import numpy as np
 import pycls.core.logging as logging
 import torch.utils.data
-from iopath.common.file_io import g_pathmgr
 from pycls.core.config import cfg
+from pycls.core.io import pathmgr
 
 
 logger = logging.get_logger(__name__)
@@ -28,11 +28,10 @@ class Cifar10(torch.utils.data.Dataset):
     """CIFAR-10 dataset."""
 
     def __init__(self, data_path, split):
-        assert g_pathmgr.exists(data_path), "Data path '{}' not found".format(data_path)
+        assert pathmgr.exists(data_path), "Data path '{}' not found".format(data_path)
         splits = ["train", "test"]
         assert split in splits, "Split '{}' not supported for cifar".format(split)
         logger.info("Constructing CIFAR-10 {}...".format(split))
-        self._im_size = cfg.TRAIN.IM_SIZE
         self._data_path, self._split = data_path, split
         self._inputs, self._labels = self._load_data()
 
@@ -48,13 +47,14 @@ def _load_data(self):
         inputs, labels = [], []
         for batch_name in batch_names:
             batch_path = os.path.join(self._data_path, batch_name)
-            with g_pathmgr.open(batch_path, "rb") as f:
+            with pathmgr.open(batch_path, "rb") as f:
                 data = pickle.load(f, encoding="bytes")
             inputs.append(data[b"data"])
             labels += data[b"labels"]
         # Combine and reshape the inputs
+        assert cfg.TRAIN.IM_SIZE == 32, "CIFAR-10 images are 32x32"
         inputs = np.vstack(inputs).astype(np.float32)
-        inputs = inputs.reshape((-1, 3, self._im_size, self._im_size))
+        inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE))
         return inputs, labels
 
     def _prepare_im(self, im):
@@ -64,7 +64,7 @@
         im[i] = (im[i] - _MEAN[i]) / _STD[i]
         if self._split == "train":
             # Randomly flip and crop center patch from CHW image
-            size = self._im_size
+            size = cfg.TRAIN.IM_SIZE
             im = im[:, :, ::-1] if np.random.uniform() < 0.5 else im
             im = np.pad(im, ((0, 0), (4, 4), (4, 4)), mode="constant")
             y = np.random.randint(0, im.shape[1] - size)
diff --git a/pycls/models/model_zoo.py b/pycls/models/model_zoo.py
index cdd8f49..ae7447c 100644
--- a/pycls/models/model_zoo.py
+++ b/pycls/models/model_zoo.py
@@ -11,7 +11,7 @@
 
 import pycls.core.builders as builders
 import pycls.core.checkpoint as cp
-from pycls.core.config import cfg, reset_cfg
+from pycls.core.config import cfg, load_cfg, reset_cfg
 from pycls.core.io import cache_url
 
 
@@ -141,7 +141,7 @@ def build_model(name, pretrained=False, cfg_list=()):
     # Load the config
     reset_cfg()
     config_file = get_config_file(name)
-    cfg.merge_from_file(config_file)
+    load_cfg(config_file)
     cfg.merge_from_list(cfg_list)
     # Construct model
     model = builders.build_model()
diff --git a/requirements.txt b/requirements.txt
index afd5d1d..c72334c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 black==19.3b0
 isort==4.3.21
-fvcore
+iopath
 flake8
 matplotlib
 numpy
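
Usage note: after this cleanup, downstream code loads configs via load_cfg()
and routes all file io through the pycls-owned pathmgr singleton. A minimal
sketch follows (not part of the patch; "my_config.yaml" and "notes.txt" are
placeholder paths, not files shipped with pycls):

    from pycls.core.config import cfg, load_cfg, reset_cfg
    from pycls.core.io import pathmgr

    # Reset the global config to its defaults, then merge a YAML config file.
    # load_cfg(cfg_file) replaces the removed merge_from_file() helper.
    reset_cfg()
    load_cfg("my_config.yaml")  # placeholder path

    # File io goes through the module-level path manager created by
    # PathManagerFactory.get() in pycls/core/io.py, not the deprecated
    # process-global g_pathmgr from iopath.
    pathmgr.mkdirs(cfg.OUT_DIR)
    with pathmgr.open("notes.txt", "w") as f:  # placeholder path
        f.write("example\n")

A per-project PathManager instance keeps any path handlers pycls registers
from colliding with handlers registered by other libraries in the same
process, which is the usual motivation for moving off a shared global manager.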