Add multi-node support using submitit / SLURM (facebookresearch#157)
Summary:
- docs/GETTING_STARTED.md: example usage
- config.py: added options for multi-node training
- distributed.py: added submitit support for local and slurm jobs
- trainer.py: minor tweaks to support multi-node execution
- requirements.txt: added submitit requirement

Pull Request resolved: facebookresearch#157

Reviewed By: pdollar

Differential Revision: D29630024

Pulled By: mannatsingh

fbshipit-source-id: 50e61ee56dc24fd2f9868687f77a18380732cb37

Co-authored-by: Mannat Singh <13458796+mannatsingh@users.noreply.github.com>
Co-authored-by: Piotr Dollar <699682+pdollar@users.noreply.github.com>
3 people authored and facebook-github-bot committed Jul 9, 2021
1 parent 750a36e commit f8cd962
Showing 5 changed files with 94 additions and 6 deletions.
9 changes: 9 additions & 0 deletions docs/GETTING_STARTED.md
@@ -75,6 +75,15 @@ The examples below use a config for RegNetX-400MF on ImageNet with 8 GPUs.
OUT_DIR /tmp
```

### Model Evaluation (multi-node)

```
./tools/run_net.py --mode test \
--cfg configs/dds_baselines/regnetx/RegNetX-400MF_dds_8gpu.yaml \
TEST.WEIGHTS https://dl.fbaipublicfiles.com/pycls/dds_baselines/160905967/RegNetX-400MF_dds_8gpu.pyth \
OUT_DIR test/ LOG_DEST file LAUNCH.MODE slurm LAUNCH.PARTITION devlab NUM_GPUS 16 LAUNCH.NAME pycls_eval_test
```

### Model Training

```
31 changes: 31 additions & 0 deletions pycls/core/config.py
@@ -375,6 +375,29 @@
_C.PREC_TIME.NUM_ITER = 30


# ---------------------------------- Launch options ---------------------------------- #
_C.LAUNCH = CfgNode()

# The launch mode: 'local', 'slurm', or 'submitit_local' (the latter for debugging).
# The 'local' mode runs a multi-GPU job on the current node via
# torch.multiprocessing.start_processes. The 'slurm' mode uses submitit to submit
# a job to a SLURM cluster and is the only way to launch multi-node jobs.
# In 'slurm' mode, the LAUNCH options below control the SLURM job parameters.
# Note that NUM_GPUS (not part of LAUNCH options) determines the total GPUs requested.
_C.LAUNCH.MODE = "local"

# Launch options that are only used if LAUNCH.MODE is 'slurm'
_C.LAUNCH.MAX_RETRY = 3
_C.LAUNCH.NAME = "pycls_job"
_C.LAUNCH.COMMENT = ""
_C.LAUNCH.CPUS_PER_GPU = 10
_C.LAUNCH.MEM_PER_GPU = 60
_C.LAUNCH.PARTITION = "devlab"
_C.LAUNCH.GPU_TYPE = "volta"
_C.LAUNCH.TIME_LIMIT = 4200
_C.LAUNCH.EMAIL = ""
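
Since LAUNCH is a regular yacs CfgNode, these settings can also be overridden programmatically, not only on the command line as in the GETTING_STARTED example. A minimal sketch (the override values here are arbitrary):

```
# Illustrative only: request a two-node SLURM launch by overriding the defaults.
from pycls.core.config import cfg

cfg.merge_from_list([
    "LAUNCH.MODE", "slurm",        # submit via submitit instead of running locally
    "LAUNCH.PARTITION", "devlab",  # cluster-specific partition name
    "NUM_GPUS", "16",              # with MAX_GPUS_PER_NODE=8 this means 2 nodes
])
```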


# ----------------------------------- Misc options ----------------------------------- #
# Optional description of a config
_C.DESC = ""
@@ -385,6 +408,9 @@
# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1

# Maximum number of GPUs available per node (unlikely to need to be changed)
_C.MAX_GPUS_PER_NODE = 8

# Output directory
_C.OUT_DIR = "/tmp"

@@ -440,6 +466,11 @@ def assert_cfg():
    assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
    err_str = "Log destination '{}' not supported"
    assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
    err_str = "NUM_GPUS must be divisible by or less than MAX_GPUS_PER_NODE"
    num_gpus, max_gpus_per_node = _C.NUM_GPUS, _C.MAX_GPUS_PER_NODE
    assert num_gpus <= max_gpus_per_node or num_gpus % max_gpus_per_node == 0, err_str
    err_str = "Invalid mode {}".format(_C.LAUNCH.MODE)
    assert _C.LAUNCH.MODE in ["local", "submitit_local", "slurm"], err_str
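
For intuition on this check: NUM_GPUS together with MAX_GPUS_PER_NODE determines how a job is split across nodes, the same arithmetic that appears in distributed.py below. A small sketch, with the helper name chosen here purely for illustration:

```
def slurm_shape(num_gpus, max_gpus_per_node=8):
    """Illustrative only: map a total GPU count to (nodes, gpus_per_node)."""
    assert num_gpus <= max_gpus_per_node or num_gpus % max_gpus_per_node == 0
    return max(1, num_gpus // max_gpus_per_node), min(num_gpus, max_gpus_per_node)

print(slurm_shape(4))   # (1, 4): a partial single node is fine
print(slurm_shape(16))  # (2, 8): multi-node jobs must use full nodes
print(slurm_shape(24))  # (3, 8)
```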


def dump_cfg():
55 changes: 52 additions & 3 deletions pycls/core/distributed.py
@@ -10,6 +10,7 @@
import os
import random

import submitit
import torch
from pycls.core.config import cfg

@@ -18,15 +19,37 @@
os.environ["MKL_THREADING_LAYER"] = "GNU"


def is_master_proc():
class SubmititRunner(submitit.helpers.Checkpointable):
    """A callable which is passed to submitit to launch the jobs."""

    def __init__(self, port, fun, cfg_state):
        self.cfg_state = cfg_state
        self.port = port
        self.fun = fun

    def __call__(self):
        job_env = submitit.JobEnvironment()
        os.environ["MASTER_ADDR"] = job_env.hostnames[0]
        os.environ["MASTER_PORT"] = str(self.port)
        os.environ["RANK"] = str(job_env.global_rank)
        os.environ["LOCAL_RANK"] = str(job_env.local_rank)
        os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
        setup_distributed(self.cfg_state)
        self.fun()
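
setup_distributed() itself is not part of this diff; for context, a worker launched this way would typically initialize torch.distributed from the environment variables exported above, roughly as sketched here (NCCL backend assumed):

```
import os

import torch

def init_from_env():
    # Illustrative only: with init_method="env://", PyTorch reads MASTER_ADDR,
    # MASTER_PORT, RANK, and WORLD_SIZE, which SubmititRunner exports above.
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
```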


def is_master_proc(local=False):
    """
    Determines if the current process is the master process.
    The master process is responsible for logging and for writing and loading
    checkpoints. In the multi-GPU setting, the rank 0 process is the master. When
    training on a single GPU, the single process is considered the master.
    If local is True, check if the current process is the master on the current node.
    """
    return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
    m = cfg.MAX_GPUS_PER_NODE if local else cfg.NUM_GPUS
    return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() % m == 0
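
As a concrete example of the local flag: on a 16-GPU job spread over two 8-GPU nodes, only rank 0 is the global master, while ranks 0 and 8 are each the master of their own node. A tiny sketch of the same modulo check:

```
def master_ranks(num_gpus=16, max_gpus_per_node=8):
    # Illustrative only: which ranks pass the global vs. per-node master check.
    global_masters = [r for r in range(num_gpus) if r % num_gpus == 0]
    local_masters = [r for r in range(num_gpus) if r % max_gpus_per_node == 0]
    return global_masters, local_masters

print(master_ranks())  # ([0], [0, 8])
```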


def scaled_all_reduce(tensors):
@@ -85,7 +108,33 @@ def single_proc_run(local_rank, fun, master_port, cfg_state, world_size):

def multi_proc_run(num_proc, fun):
    """Run a single or multi GPU job locally on the current node."""
    if num_proc > 1:
    launch = cfg.LAUNCH
    if launch.MODE in ["submitit_local", "slurm"]:
        # Launch fun() using submitit either locally or on SLURM
        use_slurm = launch.MODE == "slurm"
        executor = submitit.AutoExecutor if use_slurm else submitit.LocalExecutor
        kwargs = {"slurm_max_num_timeout": launch.MAX_RETRY} if use_slurm else {}
        executor = executor(folder=cfg.OUT_DIR, **kwargs)
        num_gpus_per_node = min(cfg.NUM_GPUS, cfg.MAX_GPUS_PER_NODE)
        executor.update_parameters(
            mem_gb=launch.MEM_PER_GPU * num_gpus_per_node,
            gpus_per_node=num_gpus_per_node,
            tasks_per_node=num_gpus_per_node,
            cpus_per_task=launch.CPUS_PER_GPU,
            nodes=max(1, cfg.NUM_GPUS // cfg.MAX_GPUS_PER_NODE),
            timeout_min=launch.TIME_LIMIT,
            name=launch.NAME,
            slurm_partition=launch.PARTITION,
            slurm_comment=launch.COMMENT,
            slurm_constraint=launch.GPU_TYPE,
            slurm_additional_parameters={"mail-user": launch.EMAIL, "mail-type": "END"},
        )
        master_port = random.randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1])
        job = executor.submit(SubmititRunner(master_port, fun, cfg))
        print("Submitted job_id {} with out_dir: {}".format(job.job_id, cfg.OUT_DIR))
        if not use_slurm:
            job.wait()
    elif num_proc > 1:
        master_port = random.randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1])
        mp_runner = torch.multiprocessing.start_processes
        args = (fun, master_port, cfg, num_proc)
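For context on how this entry point is driven (not shown in this diff; the trainer function name is assumed from the rest of the codebase): a caller such as tools/run_net.py would hand a task function to multi_proc_run, roughly as below. With LAUNCH.MODE='slurm' the call returns as soon as the job is submitted, while 'local' and 'submitit_local' block until the workers finish.

```
# Illustrative only: a minimal driver in the spirit of tools/run_net.py.
import pycls.core.distributed as dist
import pycls.core.trainer as trainer
from pycls.core.config import cfg

def main():
    # cfg is assumed to have been loaded and validated elsewhere.
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.train_model)

if __name__ == "__main__":
    main()
```
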
4 changes: 1 addition & 3 deletions pycls/core/trainer.py
@@ -64,8 +64,6 @@ def setup_model():
    # Log model complexity
    logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
    # Transfer the model to the current GPU device
    err_str = "Cannot use more GPU devices than available"
    assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
    cur_device = torch.cuda.current_device()
    model = model.cuda(device=cur_device)
    # Use multi-process data parallel model in the multi-gpu setting
@@ -78,7 +76,7 @@

def get_weights_file(weights_file):
    """Download weights file if stored as a URL."""
    download = dist.is_master_proc()
    download = dist.is_master_proc(local=True)
    weights_file = cache_url(weights_file, cfg.DOWNLOAD_CACHE, download=download)
    if cfg.NUM_GPUS > 1:
        torch.distributed.barrier()
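A note on the local=True change above: with multiple nodes, each node has its own filesystem cache, so the weights must be fetched once per node rather than once per job, and the barrier keeps the remaining ranks waiting until the file exists locally. A generic sketch of the same pattern (helper and argument names are illustrative):

```
import os
import urllib.request

import torch.distributed as torch_dist

def fetch_once_per_node(url, cache_dir, is_local_master):
    """Illustrative only: the per-node master downloads, every rank waits."""
    path = os.path.join(cache_dir, os.path.basename(url))
    if is_local_master and not os.path.exists(path):
        os.makedirs(cache_dir, exist_ok=True)
        urllib.request.urlretrieve(url, path)
    if torch_dist.is_available() and torch_dist.is_initialized():
        torch_dist.barrier()  # resume once the file is in the node-local cache
    return path
```
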
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,5 +9,6 @@ opencv-python==4.2.0.34
parameterized
setuptools
simplejson
submitit
yacs
yattag
