modernize distributed using torch.multiprocessing (facebookresearch#152)
Summary:
- distributed.py: major refactor
- trainer.py: log environment

Pull Request resolved: facebookresearch#152

Reviewed By: pdollar

Differential Revision: D29572668

Pulled By: mannatsingh

fbshipit-source-id: c5cca03dc94b6f196123e5f8771853f70aabe5a2

Co-authored-by: Mannat Singh <13458796+mannatsingh@users.noreply.github.com>
Co-authored-by: Piotr Dollar <699682+pdollar@users.noreply.github.com>
3 people authored and facebook-github-bot committed Jul 7, 2021
1 parent 74df325 commit b0316d8
Showing 2 changed files with 45 additions and 109 deletions.
151 changes: 42 additions & 109 deletions pycls/core/distributed.py
@@ -7,12 +7,8 @@
 
 """Distributed helpers."""
 
-import multiprocessing
 import os
 import random
-import signal
-import threading
-import traceback
 
 import torch
 from pycls.core.config import cfg
@@ -23,7 +19,8 @@
 
 
 def is_master_proc():
-    """Determines if the current process is the master process.
+    """
+    Determines if the current process is the master process.
 
     Master process is responsible for logging, writing and loading checkpoints. In
     the multi GPU setting, we assign the master role to the rank 0 process. When
@@ -32,26 +29,9 @@ def is_master_proc():
     return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
 
 
-def init_process_group(proc_rank, world_size, port):
-    """Initializes the default process group."""
-    # Set the GPU to use
-    torch.cuda.set_device(proc_rank)
-    # Initialize the process group
-    torch.distributed.init_process_group(
-        backend=cfg.DIST_BACKEND,
-        init_method="tcp://{}:{}".format(cfg.HOST, port),
-        world_size=world_size,
-        rank=proc_rank,
-    )
-
-
-def destroy_process_group():
-    """Destroys the default process group."""
-    torch.distributed.destroy_process_group()
-
-
 def scaled_all_reduce(tensors):
-    """Performs the scaled all_reduce operation on the provided tensors.
+    """
+    Performs the scaled all_reduce operation on the provided tensors.
 
     The input tensors are modified in-place. Currently supports only the sum
     reduction operator. The reduced values are scaled by the inverse size of the
@@ -74,91 +54,44 @@ def scaled_all_reduce(tensors):
     return tensors
 
 
-class ChildException(Exception):
-    """Wraps an exception from a child process."""
-
-    def __init__(self, child_trace):
-        super(ChildException, self).__init__(child_trace)
-
-
-class ErrorHandler(object):
-    """Multiprocessing error handler (based on fairseq's).
-
-    Listens for errors in child processes and propagates the tracebacks to the parent.
-    """
-
-    def __init__(self, error_queue):
-        # Shared error queue
-        self.error_queue = error_queue
-        # Children processes sharing the error queue
-        self.children_pids = []
-        # Start a thread listening to errors
-        self.error_listener = threading.Thread(target=self.listen, daemon=True)
-        self.error_listener.start()
-        # Register the signal handler
-        signal.signal(signal.SIGUSR1, self.signal_handler)
-
-    def add_child(self, pid):
-        """Registers a child process."""
-        self.children_pids.append(pid)
-
-    def listen(self):
-        """Listens for errors in the error queue."""
-        # Wait until there is an error in the queue
-        child_trace = self.error_queue.get()
-        # Put the error back for the signal handler
-        self.error_queue.put(child_trace)
-        # Invoke the signal handler
-        os.kill(os.getpid(), signal.SIGUSR1)
-
-    def signal_handler(self, _sig_num, _stack_frame):
-        """Signal handler."""
-        # Kill children processes
-        for pid in self.children_pids:
-            os.kill(pid, signal.SIGINT)
-        # Propagate the error from the child process
-        raise ChildException(self.error_queue.get())
-
-
-def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs):
-    """Runs a function from a child process."""
-    try:
-        # Initialize the process group
-        init_process_group(proc_rank, world_size, port)
-        # Run the function
-        fun(*fun_args, **fun_kwargs)
-    except KeyboardInterrupt:
-        # Killed by the parent process
-        pass
-    except Exception:
-        # Propagate exception to the parent process
-        error_queue.put(traceback.format_exc())
-    finally:
-        # Destroy the process group
-        destroy_process_group()
-
-
-def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
-    """Runs a function in a multi-proc setting (unless num_proc == 1)."""
-    # There is no need for multi-proc in the single-proc case
-    fun_kwargs = fun_kwargs if fun_kwargs else {}
-    if num_proc == 1:
-        fun(*fun_args, **fun_kwargs)
-        return
-    # Handle errors from training subprocesses
-    error_queue = multiprocessing.SimpleQueue()
-    error_handler = ErrorHandler(error_queue)
-    # Get a random port to use (without using global random number generator)
-    port = random.Random().randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1])
-    # Run each training subprocess
-    ps = []
-    for i in range(num_proc):
-        p_i = multiprocessing.Process(
-            target=run, args=(i, num_proc, port, error_queue, fun, fun_args, fun_kwargs)
-        )
-        ps.append(p_i)
-        p_i.start()
-        error_handler.add_child(p_i.pid)
-    # Wait for each subprocess to finish
-    for p in ps:
-        p.join()
+def setup_distributed(cfg_state):
+    """
+    Initialize torch.distributed and set the CUDA device.
+
+    Expects environment variables to be set as per
+    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
+    along with the environ variable "LOCAL_RANK" which is used to set the CUDA device.
+
+    This is run inside a new process, so the cfg is reset and must be set explicitly.
+    """
+    cfg.defrost()
+    cfg.update(**cfg_state)
+    cfg.freeze()
+    local_rank = int(os.environ["LOCAL_RANK"])
+    torch.distributed.init_process_group(backend=cfg.DIST_BACKEND)
+    torch.cuda.set_device(local_rank)
+
+
+def single_proc_run(local_rank, fun, master_port, cfg_state, world_size):
+    """Executes fun() on a single GPU in a multi-GPU setup."""
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(master_port)
+    os.environ["RANK"] = str(local_rank)
+    os.environ["LOCAL_RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    setup_distributed(cfg_state)
+    fun()
+
+
+def multi_proc_run(num_proc, fun):
+    """Run a single or multi GPU job locally on the current node."""
+    if num_proc > 1:
+        master_port = random.randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1])
+        mp_runner = torch.multiprocessing.start_processes
+        args = (fun, master_port, cfg, num_proc)
+        # Note: using "fork" below, "spawn" causes time and error regressions. Using
+        # spawn changes the default multiprocessing context to spawn, which doesn't
+        # interact well with the dataloaders (likely due to the use of OpenCV).
+        mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="fork")
+    else:
+        fun()
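
The new multi_proc_run(num_proc, fun) drops the old fun_args/fun_kwargs parameters, so a caller now passes a zero-argument callable and binds any arguments itself. A minimal caller sketch, assuming a placeholder train_model rather than the real pycls trainer entry point:

# Hypothetical caller sketch: bind arguments with functools.partial, since the new
# multi_proc_run no longer forwards fun_args/fun_kwargs to the worker function.
from functools import partial

import pycls.core.distributed as dist
from pycls.core.config import cfg


def train_model(mode="train"):
    """Placeholder training entry point (assumption, not part of this diff)."""
    print(f"running {mode} on one process")


if __name__ == "__main__":
    # Old API: dist.multi_proc_run(cfg.NUM_GPUS, train_model, fun_args=("train",))
    # New API: pass a zero-argument callable with any arguments bound up front.
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=partial(train_model, mode="train"))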
3 changes: 3 additions & 0 deletions pycls/core/trainer.py
@@ -7,6 +7,7 @@
 
 """Tools for training and testing a model."""
 
+import os
 import random
 from copy import deepcopy
 
@@ -42,6 +43,8 @@ def setup_env():
     # Log torch, cuda, and cudnn versions
     version = [torch.__version__, torch.version.cuda, torch.backends.cudnn.version()]
     logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
+    env = "".join([f"{key}: {value}\n" for key, value in sorted(os.environ.items())])
+    logger.info(f"os.environ:\n{env}")
     # Log the config as both human readable and as a json
     logger.info("Config:\n{}".format(cfg)) if cfg.VERBOSE else ()
     logger.info(logging.dump_log_data(cfg, "cfg", None))
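
Under the new launch path, the os.environ dump added to setup_env() also records the rendezvous variables that single_proc_run exports (MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, WORLD_SIZE). A small standalone sketch of that log format, with example values only:

# Illustrative sketch: reproduce the environment dump added to setup_env() using
# example values for the distributed variables set by single_proc_run.
import os

os.environ.setdefault("MASTER_ADDR", "localhost")  # example value
os.environ.setdefault("MASTER_PORT", "29500")      # example value
os.environ.setdefault("RANK", "0")                 # example value
os.environ.setdefault("LOCAL_RANK", "0")           # example value
os.environ.setdefault("WORLD_SIZE", "1")           # example value

env = "".join([f"{key}: {value}\n" for key, value in sorted(os.environ.items())])
print(f"os.environ:\n{env}")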
