forked from facebookresearch/pycls
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Unify entry points into model code into run_net.py (facebookresearch#148
) Summary: The new usage is ./tools/run_net --mode MODE --cfg CFG_FILE Valid choices for MODE are {info, scale, test, time, train} Note that the info mode is new (and prints the model and complexity) See GETTING_STARTED for the new usage Details: -GETTING_STARTED.md: updated documentation -core/config.py: removed load_cfg_from_args (no longer used) -scaler.py: added documentation that was previously in scale_net.py -tools/run_net.py encompasses all the individual scripts now -tools/{scale, test, time, train}.py are all obsolete now -sweep_launch.py, swee_launch_job.py, sweep/config.py: added MODE Pull Request resolved: facebookresearch#148 Test Plan: Tested all the instructions in the getting started docs - ``` ./tools/run_net.py --mode info \ --cfg configs/dds_baselines/regnetx/RegNetX-400MF_dds_8gpu.yaml ``` ``` ./tools/run_net.py --mode test \ --cfg configs/dds_baselines/regnetx/RegNetX-400MF_dds_8gpu.yaml \ TEST.WEIGHTS https://dl.fbaipublicfiles.com/pycls/dds_baselines/160905967/RegNetX-400MF_dds_8gpu.pyth \ OUT_DIR /tmp ``` ``` ./tools/run_net.py --mode train \ --cfg configs/dds_baselines/regnetx/RegNetX-400MF_dds_8gpu.yaml \ OUT_DIR /tmp ``` ``` ./tools/run_net.py --mode train \ --cfg configs/dds_baselines/regnetx/RegNetX-400MF_dds_8gpu.yaml \ TRAIN.WEIGHTS https://dl.fbaipublicfiles.com/pycls/dds_baselines/160905967/RegNetX-400MF_dds_8gpu.pyth \ OUT_DIR /tmp ``` ``` ./tools/run_net.py --mode time \ --cfg configs/dds_baselines/regnetx/RegNetX-400MF_dds_8gpu.yaml \ NUM_GPUS 1 \ TRAIN.BATCH_SIZE 64 \ TEST.BATCH_SIZE 64 \ PREC_TIME.WARMUP_ITER 5 \ PREC_TIME.NUM_ITER 50 ``` ``` ./tools/run_net.py --mode scale \ --cfg configs/dds_baselines/regnety/RegNetY-4.0GF_dds_8gpu.yaml \ OUT_DIR ./ \ CFG_DEST "RegNetY-4.0GF_dds_8gpu_scaled.yaml" \ MODEL.SCALING_FACTOR 4.0 \ MODEL.SCALING_TYPE "d1_w8_g8_r1" ``` Also tested a sweep launch - ``` SWEEP_CFG=configs/sweeps/cifar/cifar_optim.yaml ./tools/sweep_setup.py --sweep-cfg $SWEEP_CFG ./tools/sweep_launch.py --sweep-cfg $SWEEP_CFG ``` Reviewed By: pdollar Differential Revision: D29275940 Pulled By: mannatsingh fbshipit-source-id: af463d014d259bf8483b981a57a2a85c10209252
- Loading branch information
1 parent
f20820e
commit 5b57451
Showing
11 changed files
with
110 additions
and
150 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Execute various operations (train, test, time, etc.) on a classification model.""" | ||
|
||
import argparse | ||
import sys | ||
|
||
import pycls.core.builders as builders | ||
import pycls.core.config as config | ||
import pycls.core.distributed as dist | ||
import pycls.core.net as net | ||
import pycls.core.trainer as trainer | ||
import pycls.models.scaler as scaler | ||
from pycls.core.config import cfg | ||
|
||
|
||
def parse_args(): | ||
"""Parse command line options (mode and config).""" | ||
parser = argparse.ArgumentParser(description="Run a model.") | ||
help_s, choices = "Run mode", ["info", "train", "test", "time", "scale"] | ||
parser.add_argument("--mode", help=help_s, choices=choices, required=True, type=str) | ||
help_s = "Config file location" | ||
parser.add_argument("--cfg", help=help_s, required=True, type=str) | ||
help_s = "See pycls/core/config.py for all options" | ||
parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER) | ||
if len(sys.argv) == 1: | ||
parser.print_help() | ||
sys.exit(1) | ||
return parser.parse_args() | ||
|
||
|
||
def main(): | ||
"""Execute operation (train, test, time, etc.).""" | ||
args = parse_args() | ||
mode = args.mode | ||
config.load_cfg(args.cfg) | ||
cfg.merge_from_list(args.opts) | ||
config.assert_and_infer_cfg() | ||
cfg.freeze() | ||
if mode == "info": | ||
print(builders.get_model()()) | ||
print("complexity:", net.complexity(builders.get_model())) | ||
elif mode == "train": | ||
dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.train_model) | ||
elif mode == "test": | ||
dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model) | ||
elif mode == "time": | ||
dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.time_model) | ||
elif mode == "scale": | ||
cfg.defrost() | ||
cx_orig = net.complexity(builders.get_model()) | ||
scaler.scale_model() | ||
cx_scaled = net.complexity(builders.get_model()) | ||
cfg_file = config.dump_cfg() | ||
print("Scaled config dumped to:", cfg_file) | ||
print("Original model complexity:", cx_orig) | ||
print("Scaled model complexity:", cx_scaled) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.