# [c10d] Faster coalescing (pytorch#98793)
### Description
This PR reduces the CPU overhead of context-manager-style coalescing.

By "context manager style coalescing", we mean:
Sync style:
```
with _coalescing_manager():
    for i in range(num_coll):
        dist.all_reduce(tensors[i])
```
Async style:
```
with _coalescing_manager(async_ops=True) as cm:
    for i in range(num_coll):
        dist.all_reduce(tensors[i])
cm.wait()
```
In the previous implementation, each collective in the `num_coll` loop made its own call into the C++ backend, so pybind overhead accumulated with every operation.

In the new implementation, we capture the collectives at the Python level and fire a single call into C++ only when the coalescing manager exits.
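
To make the capture-then-flush idea concrete, here is a minimal, PyTorch-free sketch of the pattern. Every name in it (`FakeBackend`, `coalescing_manager`, `all_reduce`) is a hypothetical stand-in, not the actual torch.distributed implementation:

```
# Minimal sketch of "capture at Python level, flush on exit".
# All names here are hypothetical; this is not torch.distributed code.
from contextlib import contextmanager

class FakeBackend:
    """Stands in for the C++ backend; each method call models one pybind crossing."""
    def __init__(self):
        self.crossings = 0

    def allreduce_coalesced(self, values):
        self.crossings += 1                  # one crossing for the whole batch
        return values                        # placeholder result

_backend = FakeBackend()
_pending = None                              # None => not inside a coalescing block

@contextmanager
def coalescing_manager():
    global _pending
    _pending = []                            # start capturing at Python level
    try:
        yield
    finally:
        _backend.allreduce_coalesced(_pending)   # single flush at manager exit
        _pending = None

def all_reduce(value):
    if _pending is not None:
        _pending.append(value)               # captured; no backend call yet
    else:
        _backend.allreduce_coalesced([value])

with coalescing_manager():
    for i in range(64):
        all_reduce(i)

print(_backend.crossings)                    # 1 crossing instead of 64
```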

### Tests
In the current PR, the "fast path" only applies to all-reduce. Timings for reducing 512M of data, either as one flattened tensor or as 64 tensors of 8M each:
- Flattened 512M: 16.38 ms total, of which CPU time is 131.21 us
- Old `_coalescing_manager`, 64 x 8M: 22.19 ms total, of which CPU time is 2865 us
- New `_coalescing_manager`, 64 x 8M: 16.93 ms total, of which CPU time is 635 us

Hence roughly a 4x reduction in CPU overhead versus the old `_coalescing_manager` (2865 us -> 635 us); the exact gain depends on `num_coll`.
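
A hedged sketch of how this kind of measurement could be reproduced (not the actual benchmark script; assumes an initialized NCCL process group, one CUDA device per rank, and that "8M" above means 8 MB per tensor):

```
import time
import torch
import torch.distributed as dist
from torch.distributed import _coalescing_manager

# Hedged sketch only -- not the script that produced the numbers above.
num_coll = 64
numel = 8 * 1024 * 1024 // 4                       # 8 MB of fp32 per tensor (assumption)
tensors = [torch.ones(numel, device="cuda") for _ in range(num_coll)]

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

t0 = time.perf_counter()
start_evt.record()
with _coalescing_manager(async_ops=True) as cm:
    for i in range(num_coll):
        dist.all_reduce(tensors[i])
cpu_ms = (time.perf_counter() - t0) * 1e3          # CPU time spent issuing the batch
cm.wait()
end_evt.record()
torch.cuda.synchronize()
print(f"CPU {cpu_ms:.3f} ms, end-to-end {start_evt.elapsed_time(end_evt):.2f} ms")
```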

Cc @mrshenli @kumpera @wanchaol @fegin
Pull Request resolved: pytorch#98793
Approved by: https://github.com/kumpera
kwen2501 authored and pytorchmergebot committed Apr 24, 2023
1 parent 3dcc7b3 commit 3a09aa5
Showing 10 changed files with 341 additions and 111 deletions.
11 changes: 7 additions & 4 deletions torch/csrc/distributed/c10d/Backend.hpp
@@ -54,12 +54,15 @@ class TORCH_API Backend : public torch::CustomClassHolder {
}

virtual void startCoalescing() {
// no-op for backends that have not implemented startCoalescing
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not implement startCoalescing"));
}

virtual void endCoalescing(
std::vector<c10::intrusive_ptr<Work>>& /* reqs */) {
// no-op for backends that have not implemented endCoalescing
virtual c10::intrusive_ptr<Work> endCoalescing() {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not implement endCoalescing"));
}

// Subclasses must override this method to return the backend name
17 changes: 7 additions & 10 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -104,19 +104,16 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
virtual void startCoalescing(c10::DeviceType deviceType) {
// only nccl has implemented startCoalescing so only execute for nccl
// backends
if (getBackendType() == BackendType::NCCL) {
getBackend(deviceType)->startCoalescing();
}
auto backend = getBackend(deviceType);
backend->startCoalescing();
}

virtual void endCoalescing(
c10::DeviceType deviceType,
std::vector<c10::intrusive_ptr<Work>>& reqs) {
// only nccl has implemented startCoalescing so only execute for nccl
virtual c10::intrusive_ptr<Work> endCoalescing(c10::DeviceType deviceType) {
// only nccl has implemented endCoalescing so only execute for nccl
// backends
if (getBackendType() == BackendType::NCCL) {
getBackend(deviceType)->endCoalescing(reqs);
}
auto backend = getBackend(deviceType);
auto work = backend->endCoalescing();
return work;
}

virtual c10::intrusive_ptr<Work> broadcast(
96 changes: 60 additions & 36 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1389,44 +1389,62 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream)
: Backend::Options(NCCL_BACKEND_NAME),
is_high_priority_stream(is_high_priority_stream) {}

static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;

void ProcessGroupNCCL::startCoalescing() {
coalescedDevices_.clear();
coalescing_active_ = true;
coalescedComms_.clear();
coalescing_state_ |= CoalActive;
groupStart();
}

void ProcessGroupNCCL::endCoalescing(
std::vector<c10::intrusive_ptr<Work>>& reqs) {
if (!nccl_use_nonblocking()) {
c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
if (!nccl_use_nonblocking() ||
coalescedComms_.size() == 0) { // There is no actual work being coalesced
groupEnd();
} else {
std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
for (const auto& req : reqs) {
auto ncclWork = static_cast<ProcessGroupNCCL::WorkNCCL*>(req.get());
ncclComms_.insert(
ncclComms_.end(),
ncclWork->ncclComms_.begin(),
ncclWork->ncclComms_.end());
}
groupEndNonblocking(ncclComms_);
// `coalescedComms_` should have same set of comms across collectives
auto comms = coalescedComms_[0];
groupEndNonblocking(comms);
}
if (reqs.size() != coalescedDevices_.size()) {
TORCH_CHECK(false, "Number of requests do not match number of collectives");

coalescing_state_ = 0;

if (coalescedDevices_.size() == 0) {
// There is no actual work being coalesced
return nullptr;
}

int batch_idx = 0;
for (const auto& req : reqs) {
auto ncclWork = static_cast<ProcessGroupNCCL::WorkNCCL*>(req.get());
// @lint-ignore CLANGTIDY
std::vector<at::Device> devices = coalescedDevices_[batch_idx];
const auto key = getKeyFromDevices(devices);
auto& ncclStreams = ncclStreams_[key];
for (const auto i : c10::irange(devices.size())) {
(*ncclWork->ncclEndEvents_)[i].record(ncclStreams[i]);
}
batch_idx += 1;
// `coalescedDevices_` should have same set of devices across collectives
auto devices = coalescedDevices_[0];

// Create Work object
auto work = initWork(
devices, rank_, OpType::COALESCED, "nccl:coalesced", c10::nullopt);

// Record stream event
// `getKeyFromDevices` is how we get keys for both collectives and batch P2P
const auto key = getKeyFromDevices(devices);
auto& ncclStreams = ncclStreams_[key];
for (const auto i : c10::irange(devices.size())) {
auto& devEvent = (*work->ncclEndEvents_)[i];
devEvent.record(ncclStreams[i]);
}
coalescing_active_ = false;

// Set appropriate work parameters.
work->blockingWait_ = blockingWait_;
work->avoidRecordStreams_ = avoidRecordStreams_;
work->opTimeout_ = options_->timeout;
work->store_ = store_;

if (coalescing_state_ & CoalColl) {
workEnqueue(work);
// TODO: it seems we never enqueue work for single send/recv or batch P2P,
// see the `pointToPoint` function. This should be fixed. Otherwise, we risk
// not being able to abort hanged P2P ops.
}

return work;
}

template <typename Fn, typename PreProcess, typename PostProcess>
@@ -1460,8 +1478,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
const auto key = getKeyFromDevices(devices);
auto& ncclComms = getNCCLComm(key, devices, opType);

if (coalescing_active_) {
if (coalescing_state_ & CoalActive) {
coalescing_state_ |= CoalColl;
coalescedDevices_.push_back(devices);
coalescedComms_.push_back(ncclComms);
}

// Used many times below, so we stash the unordered_map lookup
@@ -1547,7 +1567,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// End event should only be recorded after the ncclGroupEnd()
for (const auto i : c10::irange(devices.size())) {
at::cuda::CUDAStream& ncclStream = ncclStreams[i];
if (!coalescing_active_) {
if (!coalescing_state_) {
(*work->ncclEndEvents_)[i].record(ncclStream);
}
work->ncclComms_[i] = ncclComms[i];
@@ -1575,7 +1595,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
work->opTimeout_ = options_->timeout;
work->store_ = store_;

workEnqueue(work);
if (!coalescing_state_) {
workEnqueue(work);
}

return work;
}
@@ -1623,8 +1645,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
}
auto& ncclComms = getNCCLComm(key, devices, opType, p2pRank, isSendRecvSelf);

if (coalescing_active_) {
if (coalescing_state_ & CoalActive) {
coalescing_state_ |= CoalP2P;
coalescedDevices_.push_back(devices);
coalescedComms_.push_back(ncclComms);
}

// First let NCCL streams wait for input tensors allocation streams
@@ -1707,7 +1731,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// End event should only be recorded after the ncclGroupEnd()
for (const auto i : c10::irange(tensors.size())) {
at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
if (!coalescing_active_) {
if (!coalescing_state_) {
(*work->ncclEndEvents_)[i].record(ncclStream);
}
work->ncclComms_[i] = ncclComms[i];
@@ -2165,8 +2189,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
_broadcast_oop(outputs_multi_dev, inputs_multi_dev, broadcastOpts);
works.push_back(work);
}
endCoalescing(works);
return initCoalescedWork(works, rank_, OpType::BROADCAST);
auto work = endCoalescing();
return work;
}
}

@@ -2290,8 +2314,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
auto work = _reduce_oop(outputs_multi_dev, inputs_multi_dev, reduceOpts);
works.push_back(work);
}
endCoalescing(works);
return initCoalescedWork(works, rank_, OpType::REDUCE);
auto work = endCoalescing();
return work;
}
}

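As a side note on the state handling above: the boolean `coalescing_active_` is replaced by a small bitmask. The snippet below is a toy illustration of how the three flags compose; the flag values and the `|=` / `&` usage come from the diff, while the surrounding control flow is simplified rather than a transcription of the C++.

```
# Toy illustration of the coalescing_state_ bitmask from the diff above.
CoalActive, CoalColl, CoalP2P = 0x01, 0x02, 0x04

state = 0
state |= CoalActive            # startCoalescing(): a coalescing block is open
state |= CoalColl              # a collective was captured inside the block
assert state & CoalActive
assert state & CoalColl
assert not (state & CoalP2P)   # no point-to-point op was captured in this block
state = 0                      # endCoalescing(): block closed, all flags cleared
```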
8 changes: 5 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -324,8 +324,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {

void startCoalescing() override;

void endCoalescing(
std::vector<c10::intrusive_ptr<Work>>& reqs) override;
c10::intrusive_ptr<Work> endCoalescing() override;

c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& tensors,
@@ -649,11 +648,14 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::set<int> usedDeviceIdxs_;

// Flag to denote if a coalescing groupStart/groupEnd block is active
bool coalescing_active_ = false;
int coalescing_state_ = 0;

// Stores device indexes for all collectives run inside a coalescing block
std::vector<std::vector<at::Device>> coalescedDevices_;

// Stores communicators for all collectives run inside a coalescing block
std::vector<std::vector<std::shared_ptr<NCCLComm>>> coalescedComms_;

// map from the key: "group name + pg counter (ID)" to the
// unique NCCL ID count. This needs to be group and pg specific
//
1 change: 1 addition & 0 deletions torch/csrc/distributed/c10d/Work.hpp
@@ -28,6 +28,7 @@ enum class OpType : std::uint8_t {
RECVANYSOURCE = 14,
BARRIER = 15,
_REDUCE_SCATTER_BASE = 16,
COALESCED = 17,
UNKNOWN = 100,
};

7 changes: 2 additions & 5 deletions torch/csrc/distributed/c10d/init.cpp
Original file line number Diff line number Diff line change
@@ -1467,12 +1467,10 @@ that adds a prefix to each key inserted to the store.
.def(
"_end_coalescing",
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
const c10::Device& device,
std::vector<c10::intrusive_ptr<::c10d::Work>>& reqs) {
self->endCoalescing(device.type(), reqs);
const c10::Device& device) {
return self->endCoalescing(device.type());
},
py::arg("device_type"),
py::arg("reqs"),
py::call_guard<py::gil_scoped_release>())
.def(
"_register_backend",
@@ -1832,7 +1830,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
.def(
"_end_coalescing",
&::c10d::Backend::endCoalescing,
py::arg("reqs"),
py::call_guard<py::gil_scoped_release>());

#ifdef USE_C10D_GLOO
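With this change, `_end_coalescing` hands back a single `Work` for the whole batch instead of mutating a `reqs` list passed in from Python. A hedged sketch of how Python-side code could drive the binding (assumes an already-initialized NCCL group, CUDA tensors, and a matching `_start_coalescing(device)` binding; illustrative only, not the actual `_coalescing_manager` implementation):

```
import torch
import torch.distributed as dist

def coalesced_all_reduce(tensors, device):
    # Illustrative only: drive the process-group coalescing bindings directly.
    pg = dist.distributed_c10d._get_default_group()
    pg._start_coalescing(device)           # assumed counterpart to _end_coalescing
    for t in tensors:
        dist.all_reduce(t, async_op=True)  # issued inside the open coalescing block
    work = pg._end_coalescing(device)      # now returns one coalesced Work (or None)
    if work is not None:
        work.wait()
```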
2 changes: 2 additions & 0 deletions torch/distributed/__init__.py
@@ -68,6 +68,8 @@ def is_available() -> bool:
_create_process_group_wrapper,
_rank_not_in_group,
_c10d_error_logger,
_coalescing_manager,
_CoalescingManager,
)

from .rendezvous import (