From 3a09aa597744aad3ce0a401b1848a140ff8c1214 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Mon, 24 Apr 2023 21:27:22 +0000
Subject: [PATCH] [c10d] Faster coalescing (#98793)

### Description
This PR reduces the CPU overhead of context-manager-style coalescing. By "context-manager-style coalescing", we mean:

Sync style:
```
with _coalescing_manager():
    for i in range(num_coll):
        dist.all_reduce(tensors[i])
```

Async style:
```
with _coalescing_manager(async_ops=True) as cm:
    for i in range(num_coll):
        dist.all_reduce(tensors[i])
cm.wait()
```

In the previous implementation, each collective in the `num_coll` loop called into the C++ backend, accumulating pybind overhead. In the new implementation, we capture the collectives at the Python level and only fire them towards C++ when the coalescing manager exits.
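To make the capture concrete: while the manager is active, each `dist.all_reduce` call only appends a `_CollOp` record to `_world.pg_coalesce_state[group]`; on exit, the manager pops the list and issues a single `allreduce_coalesced`. The snippet below is a minimal, self-contained sketch of that capture-and-flush pattern (it is not the actual PyTorch code; `FakeGroup` and `_pending` are illustrative stand-ins for the process group and `_world.pg_coalesce_state`):

```
# Simplified illustration of the capture-and-flush idea behind the new manager.
from contextlib import contextmanager

class FakeGroup:
    """Stand-in for a ProcessGroup; it just records backend crossings."""
    def __init__(self):
        self.calls = []

    def allreduce(self, tensors):
        self.calls.append(("allreduce", len(tensors)))

    def allreduce_coalesced(self, tensors):
        self.calls.append(("allreduce_coalesced", len(tensors)))

_pending = {}  # group -> captured ops (plays the role of _world.pg_coalesce_state)

def all_reduce(tensor, group):
    if group in _pending:
        # Coalescing context active: record the op, defer the backend call.
        _pending[group].append(tensor)
        return None
    # Eager path: one backend (pybind) crossing per call.
    return group.allreduce([tensor])

@contextmanager
def coalescing_manager(group):
    _pending[group] = []
    try:
        yield
    finally:
        captured = _pending.pop(group)
        if captured:
            # One backend crossing for the whole batch.
            group.allreduce_coalesced(captured)

g = FakeGroup()
with coalescing_manager(g):
    for i in range(64):
        all_reduce(f"tensor{i}", g)
print(g.calls)
```

Running the sketch prints `[('allreduce_coalesced', 64)]`: all 64 captured calls cross into the (fake) backend exactly once.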
### Tests
In the current PR, the "fast path" only applies to all-reduce.
- Flattened 512M: 16.38 ms, including CPU time 131.21 us
- Old _coalescing_manager 64 x 8M: 22.19 ms, including CPU time 2865 us
- New _coalescing_manager 64 x 8M: 16.93 ms, including CPU time 635 us

Hence a roughly 4x reduction in CPU overhead (the exact gain depends on `num_coll`).
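For reference, the correctness check added in `test_coalescing_manager` compares coalesced small all-reduces against one large all-reduce over the concatenated data. A condensed sketch of that check (assuming `dist.init_process_group("nccl")` has already run and each rank has set its own CUDA device):

```
import torch
import torch.distributed as dist

num_colls, size_per_coll = 64, 8
small = [torch.ones(size_per_coll, device="cuda") for _ in range(num_colls)]

with dist._coalescing_manager():
    for t in small:
        dist.all_reduce(t)   # captured and flushed as one allreduce_coalesced

big = torch.ones(num_colls * size_per_coll, device="cuda")
dist.all_reduce(big)         # reference: one large all-reduce

for i in range(num_colls):
    assert torch.equal(small[i], big[i * size_per_coll : (i + 1) * size_per_coll])
```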
Cc @mrshenli @kumpera @wanchaol @fegin

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98793
Approved by: https://github.com/kumpera
---
 torch/csrc/distributed/c10d/Backend.hpp        |  11 +-
 torch/csrc/distributed/c10d/ProcessGroup.hpp   |  17 +-
 .../distributed/c10d/ProcessGroupNCCL.cpp      |  96 +++++---
 .../distributed/c10d/ProcessGroupNCCL.hpp      |   8 +-
 torch/csrc/distributed/c10d/Work.hpp           |   1 +
 torch/csrc/distributed/c10d/init.cpp           |   7 +-
 torch/distributed/__init__.py                  |   2 +
 torch/distributed/distributed_c10d.py          | 230 ++++++++++++++----
 .../_internal/distributed/distributed_test.py  |  69 ++++++
 .../distributed/multi_threaded_pg.py           |  11 +-
 10 files changed, 341 insertions(+), 111 deletions(-)

diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp
index c7d49cf6acf81..728fd08ca013e 100644
--- a/torch/csrc/distributed/c10d/Backend.hpp
+++ b/torch/csrc/distributed/c10d/Backend.hpp
@@ -54,12 +54,15 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   }
 
   virtual void startCoalescing() {
-    // no-op for backends that have not implemented startCoalescing
+    TORCH_CHECK(
+        false,
+        c10::str("Backend ", getBackendName(), "does not implement startCoalescing"));
   }
 
-  virtual void endCoalescing(
-      std::vector<c10::intrusive_ptr<Work>>& /* reqs */) {
-    // no-op for backends that have not implemented endCoalescing
+  virtual c10::intrusive_ptr<Work> endCoalescing() {
+    TORCH_CHECK(
+        false,
+        c10::str("Backend ", getBackendName(), "does not implement endCoalescing"));
   }
 
   // Subclasses must override this method to return the backend name
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index 008bb84827ac5..a5bbffedac129 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -104,19 +104,16 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   virtual void startCoalescing(c10::DeviceType deviceType) {
     // only nccl has implemented startCoalescing so only execute for nccl
     // backends
-    if (getBackendType() == BackendType::NCCL) {
-      getBackend(deviceType)->startCoalescing();
-    }
+    auto backend = getBackend(deviceType);
+    backend->startCoalescing();
   }
 
-  virtual void endCoalescing(
-      c10::DeviceType deviceType,
-      std::vector<c10::intrusive_ptr<Work>>& reqs) {
-    // only nccl has implemented startCoalescing so only execute for nccl
+  virtual c10::intrusive_ptr<Work> endCoalescing(c10::DeviceType deviceType) {
+    // only nccl has implemented endCoalescing so only execute for nccl
     // backends
-    if (getBackendType() == BackendType::NCCL) {
-      getBackend(deviceType)->endCoalescing(reqs);
-    }
+    auto backend = getBackend(deviceType);
+    auto work = backend->endCoalescing();
+    return work;
   }
 
   virtual c10::intrusive_ptr<Work> broadcast(
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index badeedeaa4d93..38610d5712db0 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1389,44 +1389,62 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream)
     : Backend::Options(NCCL_BACKEND_NAME),
       is_high_priority_stream(is_high_priority_stream) {}
 
+static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
+
 void ProcessGroupNCCL::startCoalescing() {
   coalescedDevices_.clear();
-  coalescing_active_ = true;
+  coalescedComms_.clear();
+  coalescing_state_ |= CoalActive;
   groupStart();
 }
 
-void ProcessGroupNCCL::endCoalescing(
-    std::vector<c10::intrusive_ptr<Work>>& reqs) {
-  if (!nccl_use_nonblocking()) {
+c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
+  if (!nccl_use_nonblocking() ||
+      coalescedComms_.size() == 0) { // There is no actual work being coalesced
     groupEnd();
   } else {
-    std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
-    for (const auto& req : reqs) {
-      auto ncclWork = static_cast<ProcessGroupNCCL::WorkNCCL*>(req.get());
-      ncclComms_.insert(
-          ncclComms_.end(),
-          ncclWork->ncclComms_.begin(),
-          ncclWork->ncclComms_.end());
-    }
-    groupEndNonblocking(ncclComms_);
+    // `coalescedComms_` should have same set of comms across collectives
+    auto comms = coalescedComms_[0];
+    groupEndNonblocking(comms);
   }
-  if (reqs.size() != coalescedDevices_.size()) {
-    TORCH_CHECK(false, "Number of requests do not match number of collectives");
+
+  coalescing_state_ = 0;
+
+  if (coalescedDevices_.size() == 0) {
+    // There is no actual work being coalesced
+    return nullptr;
   }
-  int batch_idx = 0;
-  for (const auto& req : reqs) {
-    auto ncclWork = static_cast<ProcessGroupNCCL::WorkNCCL*>(req.get());
-    // @lint-ignore CLANGTIDY
-    std::vector<at::Device> devices = coalescedDevices_[batch_idx];
-    const auto key = getKeyFromDevices(devices);
-    auto& ncclStreams = ncclStreams_[key];
-    for (const auto i : c10::irange(devices.size())) {
-      (*ncclWork->ncclEndEvents_)[i].record(ncclStreams[i]);
-    }
-    batch_idx += 1;
+
+  // `coalescedDevices_` should have same set of devices across collectives
+  auto devices = coalescedDevices_[0];
+
+  // Create Work object
+  auto work = initWork(
+      devices, rank_, OpType::COALESCED, "nccl:coalesced", c10::nullopt);
+
+  // Record stream event
+  // `getKeyFromDevices` is how we get keys for both collectives and batch P2P
+  const auto key = getKeyFromDevices(devices);
+  auto& ncclStreams = ncclStreams_[key];
+  for (const auto i : c10::irange(devices.size())) {
+    auto& devEvent = (*work->ncclEndEvents_)[i];
+    devEvent.record(ncclStreams[i]);
   }
-  coalescing_active_ = false;
+
+  // Set appropriate work parameters.
+  work->blockingWait_ = blockingWait_;
+  work->avoidRecordStreams_ = avoidRecordStreams_;
+  work->opTimeout_ = options_->timeout;
+  work->store_ = store_;
+
+  if (coalescing_state_ & CoalColl) {
+    workEnqueue(work);
+    // TODO: it seems we never enqueue work for single send/recv or batch P2P,
+    // see the `pointToPoint` function. This should be fixed. Otherwise, we risk
+    // not being able to abort hanged P2P ops.
+  }
+
+  return work;
 }
 
 template <typename Fn, typename PreProcess, typename PostProcess>
@@ -1460,8 +1478,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   const auto key = getKeyFromDevices(devices);
   auto& ncclComms = getNCCLComm(key, devices, opType);
 
-  if (coalescing_active_) {
+  if (coalescing_state_ & CoalActive) {
+    coalescing_state_ |= CoalColl;
     coalescedDevices_.push_back(devices);
+    coalescedComms_.push_back(ncclComms);
   }
 
   // Used many times below, so we stash the unordered_map lookup
@@ -1547,7 +1567,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   // End event should only be recorded after the ncclGroupEnd()
   for (const auto i : c10::irange(devices.size())) {
     at::cuda::CUDAStream& ncclStream = ncclStreams[i];
-    if (!coalescing_active_) {
+    if (!coalescing_state_) {
       (*work->ncclEndEvents_)[i].record(ncclStream);
     }
     work->ncclComms_[i] = ncclComms[i];
@@ -1575,7 +1595,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   work->opTimeout_ = options_->timeout;
   work->store_ = store_;
 
-  workEnqueue(work);
+  if (!coalescing_state_) {
+    workEnqueue(work);
+  }
 
   return work;
 }
@@ -1623,8 +1645,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
   }
   auto& ncclComms = getNCCLComm(key, devices, opType, p2pRank, isSendRecvSelf);
 
-  if (coalescing_active_) {
+  if (coalescing_state_ & CoalActive) {
+    coalescing_state_ |= CoalP2P;
     coalescedDevices_.push_back(devices);
+    coalescedComms_.push_back(ncclComms);
   }
 
   // First let NCCL streams wait for input tensors allocation streams
@@ -1707,7 +1731,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
   // End event should only be recorded after the ncclGroupEnd()
   for (const auto i : c10::irange(tensors.size())) {
     at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
-    if (!coalescing_active_) {
+    if (!coalescing_state_) {
       (*work->ncclEndEvents_)[i].record(ncclStream);
     }
     work->ncclComms_[i] = ncclComms[i];
@@ -2165,8 +2189,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
           _broadcast_oop(outputs_multi_dev, inputs_multi_dev, broadcastOpts);
       works.push_back(work);
     }
-    endCoalescing(works);
-    return initCoalescedWork(works, rank_, OpType::BROADCAST);
+    auto work = endCoalescing();
+    return work;
   }
 }
 
@@ -2290,8 +2314,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
       auto work = _reduce_oop(outputs_multi_dev, inputs_multi_dev, reduceOpts);
       works.push_back(work);
     }
-    endCoalescing(works);
-    return initCoalescedWork(works, rank_, OpType::REDUCE);
+    auto work = endCoalescing();
+    return work;
   }
 }
 
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 5378e69cabf82..bd293b6855177 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -324,8 +324,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   void startCoalescing() override;
 
-  void endCoalescing(
-      std::vector<c10::intrusive_ptr<Work>>& reqs) override;
+  c10::intrusive_ptr<Work> endCoalescing() override;
 
   c10::intrusive_ptr<Work> broadcast(
       std::vector<at::Tensor>& tensors,
@@ -649,11 +648,14 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::set<int> usedDeviceIdxs_;
 
   // Flag to denote if a coalescing groupStart/groupEnd block is active
-  bool coalescing_active_ = false;
+  int coalescing_state_ = 0;
 
   // Stores device indexes for all collectives run inside a coalescing block
   std::vector<std::vector<at::Device>> coalescedDevices_;
 
+  // Stores communicators for all collectives run inside a coalescing block
+  std::vector<std::vector<std::shared_ptr<NCCLComm>>> coalescedComms_;
+
   // map from the key: "group name + pg counter (ID)" to the
   // unique NCCL ID count. This needs to be group and pg specific
   //
diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp
index 212ed3041457f..2a827f15ca6c3 100644
--- a/torch/csrc/distributed/c10d/Work.hpp
+++ b/torch/csrc/distributed/c10d/Work.hpp
@@ -28,6 +28,7 @@ enum class OpType : std::uint8_t {
   RECVANYSOURCE = 14,
   BARRIER = 15,
   _REDUCE_SCATTER_BASE = 16,
+  COALESCED = 17,
   UNKNOWN = 100,
 };
 
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 9e82f5d77fdbc..5aba3398430c6 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1467,12 +1467,10 @@ that adds a prefix to each key inserted to the store.
           .def(
               "_end_coalescing",
               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
-                 const c10::Device& device,
-                 std::vector<c10::intrusive_ptr<::c10d::Work>>& reqs) {
-                self->endCoalescing(device.type(), reqs);
+                 const c10::Device& device) {
+                return self->endCoalescing(device.type());
               },
               py::arg("device_type"),
-              py::arg("reqs"),
               py::call_guard<py::gil_scoped_release>())
           .def(
               "_register_backend",
@@ -1832,7 +1830,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           .def(
               "_end_coalescing",
               &::c10d::Backend::endCoalescing,
-              py::arg("reqs"),
               py::call_guard<py::gil_scoped_release>());
 
 #ifdef USE_C10D_GLOO
diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py
index c9f07a2e78253..669d7040fa65d 100644
--- a/torch/distributed/__init__.py
+++ b/torch/distributed/__init__.py
@@ -68,6 +68,8 @@ def is_available() -> bool:
         _create_process_group_wrapper,
         _rank_not_in_group,
         _c10d_error_logger,
+        _coalescing_manager,
+        _CoalescingManager,
     )
 
     from .rendezvous import (
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 715f136a290fd..8f4fd4c2de89e 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -10,7 +10,7 @@
 import warnings
 from collections import namedtuple
 from datetime import timedelta
-from typing import Any, Dict, Optional, Tuple, Union, List
+from typing import Any, Callable, Dict, Optional, Tuple, Union, List
 
 import torch
 from torch._C._distributed_c10d import (
@@ -326,6 +326,61 @@ def __getattribute__(self, key):
 reduce_op = _reduce_op()
 
 
+class P2POp:
+    """
+    A class to build point-to-point operations for ``batch_isend_irecv``.
+
+    This class builds the type of P2P operation, communication buffer, peer rank,
+    Process Group, and tag. Instances of this class will be passed to
+    ``batch_isend_irecv`` for point-to-point communications.
+
+    Args:
+        op (Callable): A function to send data to or receive data from a peer process.
+            The type of ``op`` is either ``torch.distributed.isend`` or
+            ``torch.distributed.irecv``.
+        tensor (Tensor): Tensor to send or receive.
+        peer (int): Destination or source rank.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        tag (int, optional): Tag to match send with recv.
+    """
+
+    def __init__(self, op: Callable, tensor: torch.Tensor, peer: int,
+                 group: Optional[ProcessGroup] = None, tag: int = 0):
+        self.op = op
+        self.tensor = tensor
+        self.peer = peer
+        self.group = group
+        self.tag = tag
+
+    def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int,
+                group: Optional[ProcessGroup] = None, tag: int = 0):
+        _check_op(op)
+        _check_single_tensor(tensor, "tensor")
+        return object.__new__(cls)
+
+
+class _CollOp:
+    """
+    A class to capture collective operations.
+
+    Args:
+        op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``.
+        tensor (Tensor): Tensor to operate on.
+        dst_tensor (Tensor, optional): Provided when source and destination tensors are not the same.
+        redop (ReduceOp, optional): reduce operation.
+        root (int, optional): root of broadcast or reduce.
+    """
+
+    def __init__(self, op: Callable, tensor: torch.Tensor, dst_tensor: Optional[torch.Tensor] = None,
+                 redop: Optional[ReduceOp] = None, root: Optional[int] = None):
+        self.op = op
+        self.tensor = tensor
+        self.dst_tensor = dst_tensor
+        self.redop = redop
+        self.root = root
+
+
 # DO NOT USE THESE FIELDS DIRECTLY.
 # Use them through the _world object to make sure the _world override mechanism
 _pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
@@ -337,6 +392,7 @@ def __getattribute__(self, key):
 _tags_to_pg: Dict[str, List[ProcessGroup]] = {}
 _pg_to_tag: Dict[ProcessGroup, str] = {}
 
+
 class _World:
     """
     Container class for c10d process group state.
@@ -347,6 +403,7 @@ class _World:
     """
     def __init__(self):
         self._default_pg = None
+        self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
 
     @property
     def default_pg(self):
@@ -428,6 +485,10 @@ def pg_to_tag(self) -> Dict[ProcessGroup, str]:
         global _pg_to_tag
         return _pg_to_tag
 
+    @property
+    def pg_coalesce_state(self) -> Dict[ProcessGroup, List[Union[_CollOp, P2POp]]]:
+        return self._pg_coalesce_state
+
 _world = _World()
 """Holds the singleton instance of ``_World`` used by c10.
 Experimental extension point to override it"""
@@ -1186,6 +1247,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
         _world.pg_backend_config.clear()
         _world.pg_to_tag.clear()
         _world.tags_to_pg.clear()
+        _world.pg_coalesce_state.clear()
 
         # when process group doesn't have an explicit name (only WORLD (default)
         # process group can have an explicit name), we use global _world.group_count
@@ -1201,6 +1263,12 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
         del _world.pg_names[pg]
         del _world.pg_group_ranks[pg]
         del _world.pg_backend_config[pg]
+        if pg in _world.pg_coalesce_state.keys():
+            warnings.warn(
+                "Some coalesced collectives haven't been launched when "
+                "ProcessGroup is destroyed. They will be cleaned."
+            )
+            del _world.pg_coalesce_state[pg]
 
         tag = _world.pg_to_tag.get(pg)
         del _world.pg_to_tag[pg]
@@ -1412,47 +1480,99 @@ def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proces
     return src
 
 
-class P2POp:
-    """
-    A class to build point-to-point operations for ``batch_isend_irecv``.
+class _IllegalWork(Work):
+    def __getattribute__(self, name):
+        if name in ["is_success", "exception", "wait", "source_rank", "_source_rank", "result", "synchronize"]:
+            raise RuntimeError(f"Illegal to call {name} on IllegalWork object")
 
-    This class builds the type of P2P operation, communication buffer, peer rank,
-    Process Group group, and tag. Instances of this class will be passed to
-    ``batch_isend_irecv`` for point-to-point communications.
 
-    Args:
-        op (Callable): A function to send data to or receive data from a peer process.
-            The type of ``op`` is either ``torch.distributed.isend`` or
-            ``torch.distributed.irecv``.
-        tensor (Tensor): Tensor to send or receive.
-        peer (int): Destination or source rank.
-        group (ProcessGroup, optional): The process group to work on. If None,
-            the default process group will be used.
-        tag (int, optional): Tag to match send with recv.
-    """
+class _CoalescingManager:
+    def __init__(self):
+        self.works: List[Work] = []
 
-    def __init__(self, op, tensor, peer, group=None, tag=0):
-        self.op = op
-        self.tensor = tensor
-        self.peer = peer
-        self.group = group
-        self.tag = tag
+    def append(self, work: Work):
+        if work:
+            self.works.append(work)
 
-    def __new__(cls, op, tensor, peer, group=None, tag=0):
-        _check_op(op)
-        _check_single_tensor(tensor, "tensor")
-        return object.__new__(cls)
+    def wait(self):
+        for work in self.works:
+            work.wait()
 
 
 @contextlib.contextmanager
-def _coalescing_manager(group, device, reqs):
-    if group is None:
-        group = _get_default_group()
-    group._start_coalescing(device)
+def _coalescing_manager(
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    async_ops: Optional[bool] = False,
+):
+    """
+    A context manager used to coalesce collectives or P2P operations when possible.
+
+    Args:
+        group (`ProcessGroup`, optional): The process group to work on. If None,
+            the default process group will be used.
+        device (`torch.device`, optional): Default is None, set to a device if
+            there isn't a `**_coalesced` implementation by the backend.
+        async_ops (`bool`, optional): whether the coalesced ops are async ops.
+
+    Examples:
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Synchronous ops
+        >>> with _coalescing_manager():
+        >>>     for i in range(num_colls):
+        >>>         dist.all_reduce(tensors[i])
+        >>> # Asynchronous ops
+        >>> with _coalescing_manager(async_ops=True) as cm:
+        >>>     for i in range(num_colls):
+        >>>         dist.all_reduce(tensors[i])
+        >>> cm.wait()
+
+    .. warning::
+       :func:`_coalescing_manager` currently does not support coalescing
+       all-reduces with different reduce operators, e.g. `ReduceOp.SUM` mixed
+       with `ReduceOp.PRODUCT`.
+    """
+    group = group or _get_default_group()
+    op_list = _world.pg_coalesce_state.setdefault(group, [])
+    if op_list:
+        raise RuntimeError("ProcessGroup has non-empty op list at the start of coalescing")
+    if device:
+        group._start_coalescing(device)
+    cm = _CoalescingManager()
     try:
-        yield
-    finally:
-        group._end_coalescing(device, reqs)
+        yield cm
+    except Exception:
+        # Re-throw exception caught by code inside the context manager
+        raise
+    else:
+        op_list = _world.pg_coalesce_state.pop(group)
+        if op_list:
+            # Collectives supporting "Fast Path" coalescing are captured.
+            # See implementation in corresponding collective APIs.
+            # Currently supported:
+            # - allreduce_coalesced
+            op0 = op_list[0].op
+            if op0 == all_reduce:
+                tensors = []
+                for op in op_list:
+                    tensors.append(op.tensor)
+                opts = AllreduceCoalescedOptions()
+                opts.reduceOp = op_list[0].redop
+                work = group.allreduce_coalesced(tensors, opts)
+            else:
+                raise AssertionError(
+                    f"Coalescing manager does not support fast-path coalescing of {op0}, "
+                    f"yet {op0} is still recorded in op list. This is an internal error of c10d."
+                )
+
+        if device:
+            # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding
+            work = group._end_coalescing(device)
+
+        if async_ops:
+            cm.append(work)
+        else:
+            work.wait()
 
 
 def batch_isend_irecv(p2p_op_list):
@@ -1498,20 +1618,20 @@ def batch_isend_irecv(p2p_op_list):
     _check_p2p_op_list(p2p_op_list)
     group = p2p_op_list[0].group
     device = p2p_op_list[0].tensor.device
-    reqs = []
-    with _coalescing_manager(group, device, reqs):
+    if device.type == "cuda":
+        # NCCL style coalescing
+        with _coalescing_manager(group, device, async_ops=True) as cm:
+            for p2p_op in p2p_op_list:
+                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+        return cm.works
+    else:
+        # Backward support for Gloo
+        reqs = []
         for p2p_op in p2p_op_list:
-            op = p2p_op.op
-            tensor = p2p_op.tensor
-            peer = p2p_op.peer
-            curr_group = p2p_op.group
-            tag = p2p_op.tag
-
-            ret = op(tensor, peer, curr_group, tag)
-
-            if ret is not None:
-                reqs.append(ret)
-    return reqs
+            work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+            if work:
+                reqs.append(work)
+        return reqs
 
 
 @exception_handler
@@ -1739,10 +1859,18 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
     opts = AllreduceOptions()
     opts.reduceOp = op
     if group is None:
-        default_pg = _get_default_group()
-        work = default_pg.allreduce([tensor], opts)
-    else:
-        work = group.allreduce([tensor], opts)
+        group = _get_default_group()
+
+    if group in _world.pg_coalesce_state.keys():
+        # We are in coalescing context, do not issue single operation, just append a collective representation
+        coll = _CollOp(all_reduce, tensor, None, op, None)
+        _world.pg_coalesce_state[group].append(coll)
+        if async_op:
+            return _IllegalWork()
+        else:
+            return None
+
+    work = group.allreduce([tensor], opts)
 
     if async_op:
         return work
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 4586553e9f69a..b9f2026def7e1 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -1248,6 +1248,75 @@ def test_3_level_hierarchical_model_averager(self):
             # No model averaging, so the parameters are not updated.
             self.assertEqual(param.data, tensor)
 
+    # Coalescing manager (sync mode)
+    @skip_if_no_gpu
+    @skip_but_pass_in_sandcastle_if(
+        BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
+        "Coalescing manager currently tests with NCCL only; internal test flaky"
+    )
+    def test_coalescing_manager(self):
+        self._barrier()
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+        device_id = rank_to_GPU[rank][0]
+        torch.cuda.set_device(device_id)
+        num_colls = 2
+        size_per_coll = 8
+        small_tensors = [
+            torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
+        ]
+
+        with dist._coalescing_manager():
+            for i in range(num_colls):
+                dist.all_reduce(small_tensors[i])
+
+        big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
+        dist.all_reduce(big_tensor)
+
+        for i in range(num_colls):
+            self.assertEqual(
+                small_tensors[i],
+                big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
+            )
+
+        self._barrier()
+
+    # Coalescing manager (async mode)
+    @skip_if_no_gpu
+    @skip_but_pass_in_sandcastle_if(
+        BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
+        "Coalescing manager currently tests with NCCL only; internal test flaky"
+    )
+    def test_coalescing_manager_async(self):
+        self._barrier()
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+        device_id = rank_to_GPU[rank][0]
+        torch.cuda.set_device(device_id)
+        num_colls = 2
+        size_per_coll = 8
+        small_tensors = [
+            torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
+        ]
+
+        with dist._coalescing_manager(async_ops=True) as cm:
+            for i in range(num_colls):
+                dist.all_reduce(small_tensors[i])
+        cm.wait()
+
+        big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
+        dist.all_reduce(big_tensor)
+
+        for i in range(num_colls):
+            self.assertEqual(
+                small_tensors[i],
+                big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
+            )
+
+        self._barrier()
+
     # NCCL Batch SEND RECV
     @skip_if_no_gpu
     @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py
index cacd1ff8ee58c..2f819729ebbb1 100644
--- a/torch/testing/_internal/distributed/multi_threaded_pg.py
+++ b/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -1,7 +1,7 @@
 import sys
 import threading
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 from functools import partial, reduce
 
 import torch
@@ -18,6 +18,7 @@
     Store,
     ReduceOp,
 )
+from torch.distributed.distributed_c10d import _CollOp, P2POp
 from torch.futures import Future
 from torch.utils._pytree import tree_flatten
 
@@ -371,13 +372,15 @@ class WorldData:
     group_count: int
     tags_to_pg: Dict[str, List[dist.ProcessGroup]]
     pg_to_tag: Dict[dist.ProcessGroup, str]
+    pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
+
 
 class ThreadLocalWorld:
     _world = threading.local()
 
     def _get_world(self) -> WorldData:
         if not hasattr(ThreadLocalWorld._world, "world"):
-            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {})
+            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {})
         return ThreadLocalWorld._world.world
 
     @property
@@ -420,6 +423,10 @@ def tags_to_pg(self):
     @property
    def pg_to_tag(self):
         return self._get_world().pg_to_tag
 
+    @property
+    def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
+        return self._get_world().pg_coalesce_state
+
 
 _old_pg_world = None