Commit 4fab913: minor cleanup after io migration to PathManager (facebookresearch#130)
Summary:
best to view this commit as a diff against ca89a79 (from 11/20/20)
this diff looks big on its own, but it is MINIMAL versus ca89a79
this diff undoes some of the orthogonal changes from the previous few diffs:

-no longer using the deprecated global path manager g_pathmgr; call sites now
 use the pycls-level pathmgr (pattern sketched below)
-cfg.merge_from_file replaced by load_cfg (the old load_cfg was unused)
-corrected / minimized requirements in requirements.txt
-setup_env() moved back into trainer.py (it should not be used elsewhere)
-removed orthogonal changes to cifar10.py introduced in previous commits
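For reference, a minimal sketch of the two call-site patterns this cleanup standardizes; the file paths are illustrative, not from the repo:

```python
# Post-cleanup usage sketch (assumes the pycls default cfg is importable).
from pycls.core.config import cfg, load_cfg, reset_cfg
from pycls.core.io import pathmgr  # pycls-level manager; replaces g_pathmgr

reset_cfg()                    # restore the default config
load_cfg("some_config.yaml")   # replaces cfg.merge_from_file(...)

# All file I/O now goes through the shared path manager:
with pathmgr.open("some_output.txt", "w") as f:
    f.write("done")
```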

Pull Request resolved: facebookresearch#130

Reviewed By: theschnitz

Differential Revision: D27128859

Pulled By: pdollar

fbshipit-source-id: 1029adca3b1bf8c33fae1fc2f82bb658f34cb56b
pdollar authored and facebook-github-bot committed Mar 18, 2021
1 parent b314d43 commit 4fab913
Showing 11 changed files with 87 additions and 98 deletions.
4 changes: 2 additions & 2 deletions dev/model_zoo_tables.py
@@ -13,7 +13,7 @@
 import pycls.core.builders as builders
 import pycls.core.net as net
 import pycls.models.model_zoo as model_zoo
-from pycls.core.config import cfg, reset_cfg
+from pycls.core.config import cfg, load_cfg, reset_cfg
 
 
 # Location of pycls directory
@@ -47,7 +47,7 @@ def get_model_data(name, timings, errors):
     """Get model data for a single model."""
     # Load model config
     reset_cfg()
-    cfg.merge_from_file(model_zoo.get_config_file(name))
+    load_cfg(model_zoo.get_config_file(name))
     config_url, _, model_id, _, weight_url_full = model_zoo.get_model_info(name)
     # Get model complexity
     cx = net.complexity(builders.get_model())
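The model-zoo flow above reduces to the pattern below; a hedged sketch, assuming the model name is a valid zoo key:

```python
import pycls.models.model_zoo as model_zoo
from pycls.core.config import load_cfg, reset_cfg

name = "RegNetX-400MF"  # illustrative zoo key
reset_cfg()
load_cfg(model_zoo.get_config_file(name))  # fetch and merge the zoo config
config_url, _, model_id, _, weight_url = model_zoo.get_model_info(name)
```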
10 changes: 4 additions & 6 deletions dev/test_models.py
@@ -16,13 +16,12 @@
 
 import pycls.core.builders as builders
 import pycls.core.distributed as dist
-import pycls.core.env as env
 import pycls.core.logging as logging
 import pycls.core.net as net
 import pycls.core.trainer as trainer
 import pycls.models.model_zoo as model_zoo
 from parameterized import parameterized
-from pycls.core.config import cfg, merge_from_file, reset_cfg
+from pycls.core.config import cfg, load_cfg, reset_cfg
 
 
 # Location of pycls directory
@@ -38,14 +37,14 @@ def test_complexity(key):
     """Measure the complexity of a single model."""
     reset_cfg()
     cfg_file = os.path.join(_PYCLS_DIR, key)
-    merge_from_file(cfg_file)
+    load_cfg(cfg_file)
     return net.complexity(builders.get_model())
 
 
 def test_timing(key):
     """Measure the timing of a single model."""
     reset_cfg()
-    merge_from_file(model_zoo.get_config_file(key))
+    load_cfg(model_zoo.get_config_file(key))
     cfg.PREC_TIME.WARMUP_ITER, cfg.PREC_TIME.NUM_ITER = 5, 50
     cfg.OUT_DIR, cfg.LOG_DEST = tempfile.mkdtemp(), "file"
     dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.time_model)
@@ -58,7 +57,7 @@ def test_timing(key):
 def test_error(key):
     """Measure the error of a single model."""
     reset_cfg()
-    merge_from_file(model_zoo.get_config_file(key))
+    load_cfg(model_zoo.get_config_file(key))
     cfg.TEST.WEIGHTS = model_zoo.get_weights_file(key)
     cfg.OUT_DIR, cfg.LOG_DEST = tempfile.mkdtemp(), "file"
     dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model)
@@ -138,7 +137,6 @@ class TestError(unittest.TestCase):
     @parameterized.expand(parse_tests(load_test_data("error")), skip_on_empty=True)
     @unittest.skipIf(not _RUN_ERROR_TESTS, "Skipping error tests")
     def test(self, key, out_expected):
-        env.setup_env()
         print("\nTesting error of: {}".format(key))
         out = test_error(key)
         print("expected = {}".format(out_expected))
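The complexity test above is self-contained enough to run outside unittest; a minimal sketch, assuming a valid config file path:

```python
import pycls.core.builders as builders
import pycls.core.net as net
from pycls.core.config import load_cfg, reset_cfg

reset_cfg()
load_cfg("configs/example.yaml")           # illustrative path
cx = net.complexity(builders.get_model())  # complexity stats for the model
print(cx)
```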
31 changes: 13 additions & 18 deletions pycls/core/checkpoint.py
@@ -8,12 +8,11 @@
 """Functions that handle saving and loading of checkpoints."""
 
 import os
-from shutil import copyfileobj
 
 import pycls.core.distributed as dist
 import torch
-from iopath.common.file_io import g_pathmgr
 from pycls.core.config import cfg
+from pycls.core.io import pathmgr
 from pycls.core.net import unwrap_model
 
 
@@ -43,17 +42,17 @@ def get_checkpoint_best():
 def get_last_checkpoint():
     """Retrieves the most recent checkpoint (highest epoch number)."""
     checkpoint_dir = get_checkpoint_dir()
-    checkpoints = [f for f in g_pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
+    checkpoints = [f for f in pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
     last_checkpoint_name = sorted(checkpoints)[-1]
     return os.path.join(checkpoint_dir, last_checkpoint_name)
 
 
 def has_checkpoint():
     """Determines if there are checkpoints available."""
     checkpoint_dir = get_checkpoint_dir()
-    if not g_pathmgr.exists(checkpoint_dir):
+    if not pathmgr.exists(checkpoint_dir):
         return False
-    return any(_NAME_PREFIX in f for f in g_pathmgr.ls(checkpoint_dir))
+    return any(_NAME_PREFIX in f for f in pathmgr.ls(checkpoint_dir))
 
 
 def save_checkpoint(model, optimizer, epoch, best):
@@ -62,7 +61,7 @@ def save_checkpoint(model, optimizer, epoch, best):
     if not dist.is_master_proc():
         return
     # Ensure that the checkpoint dir exists
-    g_pathmgr.mkdirs(get_checkpoint_dir())
+    pathmgr.mkdirs(get_checkpoint_dir())
     # Record the state
     checkpoint = {
         "epoch": epoch,
@@ -72,21 +71,19 @@ def save_checkpoint(model, optimizer, epoch, best):
     }
     # Write the checkpoint
     checkpoint_file = get_checkpoint(epoch + 1)
-    with g_pathmgr.open(checkpoint_file, "wb") as f:
+    with pathmgr.open(checkpoint_file, "wb") as f:
         torch.save(checkpoint, f)
     # If best copy checkpoint to the best checkpoint
     if best:
-        with g_pathmgr.open(checkpoint_file, "rb") as src:
-            with g_pathmgr.open(get_checkpoint_best(), "wb") as dst:
-                copyfileobj(src, dst)
+        pathmgr.copy(checkpoint_file, get_checkpoint_best())
     return checkpoint_file
 
 
 def load_checkpoint(checkpoint_file, model, optimizer=None):
     """Loads the checkpoint from the given file."""
     err_str = "Checkpoint '{}' not found"
-    assert g_pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
-    with g_pathmgr.open(checkpoint_file, "rb") as f:
+    assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
+    with pathmgr.open(checkpoint_file, "rb") as f:
         checkpoint = torch.load(f, map_location="cpu")
     unwrap_model(model).load_state_dict(checkpoint["model_state"])
     optimizer.load_state_dict(checkpoint["optimizer_state"]) if optimizer else ()
@@ -97,12 +94,10 @@ def delete_checkpoints(checkpoint_dir=None, keep="all"):
     """Deletes unneeded checkpoints, keep can be "all", "last", or "none"."""
     assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep)
     checkpoint_dir = checkpoint_dir if checkpoint_dir else get_checkpoint_dir()
-    if keep == "all" or not g_pathmgr.exists(checkpoint_dir):
+    if keep == "all" or not pathmgr.exists(checkpoint_dir):
         return 0
-    checkpoints = [f for f in g_pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
+    checkpoints = [f for f in pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
     checkpoints = sorted(checkpoints)[:-1] if keep == "last" else checkpoints
-    [
-        g_pathmgr.rm(os.path.join(checkpoint_dir, checkpoint))
-        for checkpoint in checkpoints
-    ]
+    for checkpoint in checkpoints:
+        pathmgr.rm(os.path.join(checkpoint_dir, checkpoint))
     return len(checkpoints)
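The `if best` branch above also explains the dropped shutil import: manual byte streaming between two open handles collapses into one PathManager.copy call, which works for non-local path schemes too. A sketch with illustrative paths:

```python
from pycls.core.io import pathmgr

src = "/tmp/checkpoints/model_epoch_0010.pyth"  # illustrative
dst = "/tmp/checkpoints/model.pyth"             # illustrative

# Before: open(src) + open(dst) + shutil.copyfileobj(src, dst)
# After: a single copy through the path manager
pathmgr.copy(src, dst)
```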
23 changes: 8 additions & 15 deletions pycls/core/config.py
@@ -11,8 +11,7 @@
 import os
 import sys
 
-from iopath.common.file_io import g_pathmgr
-from pycls.core.io import cache_url
+from pycls.core.io import cache_url, pathmgr
 from yacs.config import CfgNode as CfgNode
 
 
@@ -378,28 +377,22 @@ def cache_cfg_urls():
     _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
 
 
-def merge_from_file(cfg_file):
-    with g_pathmgr.open(cfg_file, "r") as f:
-        cfg = _C.load_cfg(f)
-    _C.merge_from_other_cfg(cfg)
-
-
 def dump_cfg():
     """Dumps the config to the output directory."""
     cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
-    with g_pathmgr.open(cfg_file, "w") as f:
+    with pathmgr.open(cfg_file, "w") as f:
         _C.dump(stream=f)
 
 
-def load_cfg(out_dir, cfg_dest="config.yaml"):
-    """Loads config from specified output directory."""
-    cfg_file = os.path.join(out_dir, cfg_dest)
-    merge_from_file(cfg_file)
+def load_cfg(cfg_file):
+    """Loads config from specified file."""
+    with pathmgr.open(cfg_file, "r") as f:
+        _C.merge_from_other_cfg(_C.load_cfg(f))
 
 
 def reset_cfg():
     """Reset config to initial state."""
-    cfg.merge_from_other_cfg(_CFG_DEFAULT)
+    _C.merge_from_other_cfg(_CFG_DEFAULT)
 
 
 def load_cfg_fom_args(description="Config file options."):
@@ -413,5 +406,5 @@ def load_cfg_fom_args(description="Config file options."):
         parser.print_help()
         sys.exit(1)
     args = parser.parse_args()
-    merge_from_file(args.cfg_file)
+    load_cfg(args.cfg_file)
     _C.merge_from_list(args.opts)
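After this diff the config module exposes a small, symmetric API; a hedged usage sketch (the YAML path is illustrative, and OUT_DIR is assumed to exist before dump_cfg writes there):

```python
from pycls.core.config import cfg, dump_cfg, load_cfg, reset_cfg

reset_cfg()                        # back to the defaults
load_cfg("configs/example.yaml")   # merge options from a file (any pathmgr path)
cfg.OUT_DIR = "/tmp/pycls_out"     # override options in code if needed
dump_cfg()                         # write the merged config to cfg.OUT_DIR
```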
39 changes: 6 additions & 33 deletions pycls/core/env.py
@@ -1,35 +1,8 @@
-import random
+#!/usr/bin/env python3
 
-import numpy as np
-import pycls.core.config as config
-import pycls.core.distributed as dist
-import pycls.core.logging as logging
-import torch
-from iopath.common.file_io import g_pathmgr
-from pycls.core.config import cfg
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
 
-
-logger = logging.get_logger(__name__)
-
-
-def setup_env():
-    """Sets up environment for training or testing."""
-    if dist.is_master_proc():
-        # Ensure that the output dir exists
-        g_pathmgr.mkdirs(cfg.OUT_DIR)
-        # Save the config
-        config.dump_cfg()
-    # Setup logging
-    logging.setup_logging()
-    # Log torch, cuda, and cudnn versions
-    version = [torch.__version__, torch.version.cuda, torch.backends.cudnn.version()]
-    logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
-    # Log the config as both human readable and as a json
-    logger.info("Config:\n{}".format(cfg)) if cfg.VERBOSE else ()
-    logger.info(logging.dump_log_data(cfg, "cfg", None))
-    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
-    np.random.seed(cfg.RNG_SEED)
-    torch.manual_seed(cfg.RNG_SEED)
-    random.seed(cfg.RNG_SEED)
-    # Configure the CUDNN backend
-    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
+"""This file is obsolete and will be removed in a future commit."""
13 changes: 8 additions & 5 deletions pycls/core/io.py
@@ -13,9 +13,12 @@
 import sys
 from urllib import request as urlrequest
 
-from iopath.common.file_io import g_pathmgr
+from iopath.common.file_io import PathManagerFactory
 
 
+# instantiate global path manager for pycls
+pathmgr = PathManagerFactory.get()
+
 logger = logging.getLogger(__name__)
 
 _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
@@ -31,11 +34,11 @@ def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL):
     url = url_or_file
     assert url.startswith(base_url), "url must start with: {}".format(base_url)
     cache_file_path = url.replace(base_url, cache_dir)
-    if g_pathmgr.exists(cache_file_path):
+    if pathmgr.exists(cache_file_path):
         return cache_file_path
     cache_file_dir = os.path.dirname(cache_file_path)
-    if not g_pathmgr.exists(cache_file_dir):
-        g_pathmgr.mkdirs(cache_file_dir)
+    if not pathmgr.exists(cache_file_dir):
+        pathmgr.mkdirs(cache_file_dir)
     logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
     download_url(url, cache_file_path)
     return cache_file_path
@@ -66,7 +69,7 @@ def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_ba
     total_size = response.info().get("Content-Length").strip()
     total_size = int(total_size)
     bytes_so_far = 0
-    with g_pathmgr.open(dst_file_path, "wb") as f:
+    with pathmgr.open(dst_file_path, "wb") as f:
         while 1:
             chunk = response.read(chunk_size)
             bytes_so_far += len(chunk)
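With pathmgr instantiated once in pycls/core/io.py, every module shares one PathManager instance (and any handlers registered on it); a minimal sketch with an illustrative local path:

```python
from pycls.core.io import pathmgr

out_dir = "/tmp/pycls_demo"  # illustrative
if not pathmgr.exists(out_dir):
    pathmgr.mkdirs(out_dir)
with pathmgr.open(out_dir + "/note.txt", "w") as f:
    f.write("hello pathmgr\n")
```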
10 changes: 5 additions & 5 deletions pycls/core/logging.py
@@ -15,8 +15,8 @@
 
 import pycls.core.distributed as dist
 import simplejson
-from iopath.common.file_io import g_pathmgr
 from pycls.core.config import cfg
+from pycls.core.io import pathmgr
 
 
 # Show filename and line number in logs
@@ -86,18 +86,18 @@ def float_to_decimal(data, prec=4):
 
 def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
     """Get all log files in directory containing subdirs of trained models."""
-    names = [n for n in sorted(g_pathmgr.ls(log_dir)) if name_filter in n]
+    names = [n for n in sorted(pathmgr.ls(log_dir)) if name_filter in n]
     files = [os.path.join(log_dir, n, log_file) for n in names]
-    f_n_ps = [(f, n) for (f, n) in zip(files, names) if g_pathmgr.exists(f)]
+    f_n_ps = [(f, n) for (f, n) in zip(files, names) if pathmgr.exists(f)]
     files, names = zip(*f_n_ps) if f_n_ps else ([], [])
     return files, names
 
 
 def load_log_data(log_file, data_types_to_skip=()):
     """Loads log data into a dictionary of the form data[data_type][metric][index]."""
     # Load log_file
-    assert g_pathmgr.exists(log_file), "Log file not found: {}".format(log_file)
-    with g_pathmgr.open(log_file, "r") as f:
+    assert pathmgr.exists(log_file), "Log file not found: {}".format(log_file)
+    with pathmgr.open(log_file, "r") as f:
         lines = f.readlines()
     # Extract and parse lines that start with _TAG and have a type specified
     lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
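A hedged sketch of the two log utilities above, assuming log_dir contains one subdirectory per trained model, each holding the default log file:

```python
from pycls.core.logging import get_log_files, load_log_data

log_dir = "/tmp/sweep_logs"  # illustrative
files, names = get_log_files(log_dir, name_filter="RegNet")
for log_file, name in zip(files, names):
    data = load_log_data(log_file)  # data[data_type][metric][index]
    print(name, sorted(data.keys()))
```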