diff --git a/pycls/core/distributed.py b/pycls/core/distributed.py index 3752988..eac168a 100644 --- a/pycls/core/distributed.py +++ b/pycls/core/distributed.py @@ -7,12 +7,8 @@ """Distributed helpers.""" -import multiprocessing import os import random -import signal -import threading -import traceback import torch from pycls.core.config import cfg @@ -23,7 +19,8 @@ def is_master_proc(): - """Determines if the current process is the master process. + """ + Determines if the current process is the master process. Master process is responsible for logging, writing and loading checkpoints. In the multi GPU setting, we assign the master role to the rank 0 process. When @@ -32,26 +29,9 @@ def is_master_proc(): return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0 -def init_process_group(proc_rank, world_size, port): - """Initializes the default process group.""" - # Set the GPU to use - torch.cuda.set_device(proc_rank) - # Initialize the process group - torch.distributed.init_process_group( - backend=cfg.DIST_BACKEND, - init_method="tcp://{}:{}".format(cfg.HOST, port), - world_size=world_size, - rank=proc_rank, - ) - - -def destroy_process_group(): - """Destroys the default process group.""" - torch.distributed.destroy_process_group() - - def scaled_all_reduce(tensors): - """Performs the scaled all_reduce operation on the provided tensors. + """ + Performs the scaled all_reduce operation on the provided tensors. The input tensors are modified in-place. Currently supports only the sum reduction operator. The reduced values are scaled by the inverse size of the @@ -74,91 +54,44 @@ def scaled_all_reduce(tensors): return tensors -class ChildException(Exception): - """Wraps an exception from a child process.""" - - def __init__(self, child_trace): - super(ChildException, self).__init__(child_trace) - +def setup_distributed(cfg_state): + """ + Initialize torch.distributed and set the CUDA device. -class ErrorHandler(object): - """Multiprocessing error handler (based on fairseq's). + Expects environment variables to be set as per + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + along with the environ variable "LOCAL_RANK" which is used to set the CUDA device. - Listens for errors in child processes and propagates the tracebacks to the parent. + This is run inside a new process, so the cfg is reset and must be set explicitly. """ - - def __init__(self, error_queue): - # Shared error queue - self.error_queue = error_queue - # Children processes sharing the error queue - self.children_pids = [] - # Start a thread listening to errors - self.error_listener = threading.Thread(target=self.listen, daemon=True) - self.error_listener.start() - # Register the signal handler - signal.signal(signal.SIGUSR1, self.signal_handler) - - def add_child(self, pid): - """Registers a child process.""" - self.children_pids.append(pid) - - def listen(self): - """Listens for errors in the error queue.""" - # Wait until there is an error in the queue - child_trace = self.error_queue.get() - # Put the error back for the signal handler - self.error_queue.put(child_trace) - # Invoke the signal handler - os.kill(os.getpid(), signal.SIGUSR1) - - def signal_handler(self, _sig_num, _stack_frame): - """Signal handler.""" - # Kill children processes - for pid in self.children_pids: - os.kill(pid, signal.SIGINT) - # Propagate the error from the child process - raise ChildException(self.error_queue.get()) - - -def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs): - """Runs a function from a child process.""" - try: - # Initialize the process group - init_process_group(proc_rank, world_size, port) - # Run the function - fun(*fun_args, **fun_kwargs) - except KeyboardInterrupt: - # Killed by the parent process - pass - except Exception: - # Propagate exception to the parent process - error_queue.put(traceback.format_exc()) - finally: - # Destroy the process group - destroy_process_group() - - -def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None): - """Runs a function in a multi-proc setting (unless num_proc == 1).""" - # There is no need for multi-proc in the single-proc case - fun_kwargs = fun_kwargs if fun_kwargs else {} - if num_proc == 1: - fun(*fun_args, **fun_kwargs) - return - # Handle errors from training subprocesses - error_queue = multiprocessing.SimpleQueue() - error_handler = ErrorHandler(error_queue) - # Get a random port to use (without using global random number generator) - port = random.Random().randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1]) - # Run each training subprocess - ps = [] - for i in range(num_proc): - p_i = multiprocessing.Process( - target=run, args=(i, num_proc, port, error_queue, fun, fun_args, fun_kwargs) - ) - ps.append(p_i) - p_i.start() - error_handler.add_child(p_i.pid) - # Wait for each subprocess to finish - for p in ps: - p.join() + cfg.defrost() + cfg.update(**cfg_state) + cfg.freeze() + local_rank = int(os.environ["LOCAL_RANK"]) + torch.distributed.init_process_group(backend=cfg.DIST_BACKEND) + torch.cuda.set_device(local_rank) + + +def single_proc_run(local_rank, fun, master_port, cfg_state, world_size): + """Executes fun() on a single GPU in a multi-GPU setup.""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + setup_distributed(cfg_state) + fun() + + +def multi_proc_run(num_proc, fun): + """Run a single or multi GPU job locally on the current node.""" + if num_proc > 1: + master_port = random.randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1]) + mp_runner = torch.multiprocessing.start_processes + args = (fun, master_port, cfg, num_proc) + # Note: using "fork" below, "spawn" causes time and error regressions. Using + # spawn changes the default multiprocessing context to spawn, which doesn't + # interact well with the dataloaders (likely due to the use of OpenCV). + mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="fork") + else: + fun() diff --git a/pycls/core/trainer.py b/pycls/core/trainer.py index 33c642e..a53cd2c 100644 --- a/pycls/core/trainer.py +++ b/pycls/core/trainer.py @@ -7,6 +7,7 @@ """Tools for training and testing a model.""" +import os import random from copy import deepcopy @@ -42,6 +43,8 @@ def setup_env(): # Log torch, cuda, and cudnn versions version = [torch.__version__, torch.version.cuda, torch.backends.cudnn.version()] logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version)) + env = "".join([f"{key}: {value}\n" for key, value in sorted(os.environ.items())]) + logger.info(f"os.environ:\n{env}") # Log the config as both human readable and as a json logger.info("Config:\n{}".format(cfg)) if cfg.VERBOSE else () logger.info(logging.dump_log_data(cfg, "cfg", None))