[c10d] Pass avoidRecordStreams into collective() function (pytorch#112195)

Even after PR pytorch#111431, the `collective(...)` function still uses the underscore-suffixed member `avoidRecordStreams_` internally and does not respect each collective call's preference, because `avoidRecordStreams_` is controlled only by an environment variable.

As a fix, we pass `avoidRecordStreams` into the `collective()` function as an explicit argument.
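A minimal sketch of the resulting precedence, in simplified stand-in code (the `FakeProcessGroup` type and the `AVOID_RECORD_STREAMS` env-var name are illustrative, not the real PyTorch names): because the two flags are OR'ed, the environment setting can add record-stream avoidance on top of a call's request, but never cancel an explicit per-call opt-in.

#include <cstdlib>
#include <iostream>

struct FakeProcessGroup {
  // In the real backend this member is parsed from an environment variable.
  bool avoidRecordStreams_ = std::getenv("AVOID_RECORD_STREAMS") != nullptr;

  void collective(const char* name, bool avoidRecordStreams = false) {
    // The fix: combine the caller's preference with the global default.
    avoidRecordStreams |= avoidRecordStreams_;
    std::cout << name << ": avoidRecordStreams=" << std::boolalpha
              << avoidRecordStreams << '\n';
  }
};

int main() {
  FakeProcessGroup pg;
  pg.collective("nccl:all_reduce");              // follows the env default
  pg.collective("nccl:_all_gather_base", true);  // per-call opt-in always wins
}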

Pull Request resolved: pytorch#112195
Approved by: https://github.com/awgu
kwen2501 authored and pytorchmergebot committed Oct 28, 2023
1 parent 25f06ee commit a2dcf26
Showing 3 changed files with 29 additions and 10 deletions.
23 changes: 15 additions & 8 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1958,7 +1958,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     PreProcess pre,
     PostProcess post,
     OpType opType,
-    const char* profilingTitle) {
+    const char* profilingTitle,
+    bool avoidRecordStreams) {
+  // Environment setting by the user may add onto collective call's option
+  avoidRecordStreams |= avoidRecordStreams_;
   c10::cuda::CaptureStatus capture_status =
       c10::cuda::currentStreamCaptureStatusMayInitCtx();
   errorIfCapturingNonCapturableNCCL(capture_status);
@@ -2009,7 +2012,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   // Store references to outputs to be used by WorkNCCL::result and operator<<.
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
 
-  if (avoidRecordStreams_) {
+  if (avoidRecordStreams) {
     work->stashed_for_allocator_safety_ =
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }
@@ -2052,7 +2055,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
       // operations where `inputs' and `outputs' are not the same.
       //
       // See [Sync Streams].
-      if (!avoidRecordStreams_) {
+      if (!avoidRecordStreams) {
        if (!inputs[i].is_sparse()) {
          c10::cuda::CUDACachingAllocator::recordStream(
              inputs[i].storage().data_ptr(), ncclStream);
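The two branches above correspond to the two allocator-safety strategies this flag selects between. A hedged sketch against libtorch (the function, `sideStream`, and `keepAlive` names are illustrative, not PyTorch internals): either stash strong references to the inputs until the work completes, or record each input on the extra stream that uses it.

#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

#include <memory>
#include <vector>

void launchOnSideStream(
    std::vector<at::Tensor>& inputs,
    bool avoidRecordStreams,
    std::shared_ptr<std::vector<at::Tensor>>& keepAlive) {
  c10::cuda::CUDAStream sideStream = c10::cuda::getStreamFromPool();
  if (avoidRecordStreams) {
    // Strategy A: stash strong references; the caching allocator never learns
    // about sideStream, so the stash must outlive the asynchronous work.
    keepAlive = std::make_shared<std::vector<at::Tensor>>(inputs);
  } else {
    // Strategy B: record each input on sideStream so the caching allocator
    // does not recycle its block until that stream catches up.
    for (auto& t : inputs) {
      c10::cuda::CUDACachingAllocator::recordStream(
          t.storage().data_ptr(), sideStream);
    }
  }
  // ... enqueue NCCL/compute work on sideStream here ...
}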
@@ -2111,7 +2114,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
 
   // Set appropriate work parameters.
   work->blockingWait_ = blockingWait_;
-  work->avoidRecordStreams_ = avoidRecordStreams_;
+  work->avoidRecordStreams_ = avoidRecordStreams;
   work->opTimeout_ = options_->timeout;
   work->store_ = store_;
   // Record size info for debug. We only record the size on the first device as
@@ -2322,7 +2325,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     std::vector<at::Tensor>& outputs,
     Fn fn,
     OpType opType,
-    const char* profilingTitle) {
+    const char* profilingTitle,
+    bool avoidRecordStreams) {
   return collective(
       inputs,
       outputs,
@@ -2332,7 +2336,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
       [](std::vector<at::cuda::CUDAStream>&,
          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
       opType,
-      profilingTitle);
+      profilingTitle,
+      avoidRecordStreams);
 }
 
 template <typename Fn>
@@ -3053,7 +3058,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
             stream.stream());
       },
       OpType::_REDUCE_SCATTER_BASE,
-      "nccl:_reduce_scatter_base");
+      "nccl:_reduce_scatter_base",
+      avoidRecordStreams);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
@@ -3737,7 +3743,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
             stream.stream());
       },
       OpType::_ALLGATHER_BASE,
-      "nccl:_all_gather_base");
+      "nccl:_all_gather_base",
+      avoidRecordStreams);
 }
 
 #ifdef USE_NCCL_WITH_UCC
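The two hunks above show call sites (`_reduce_scatter_base`, `_allgather_base`) forwarding a locally computed `avoidRecordStreams` instead of relying on the member. A simplified, hypothetical sketch of that call-site pattern (`runAllgatherBase` and its `asyncOp` heuristic are illustrative stand-ins, not the actual PyTorch logic):

#include <functional>
#include <iostream>

void collective(std::function<void()> fn, const char* profilingTitle,
                bool avoidRecordStreams = false) {
  std::cout << profilingTitle << ": avoidRecordStreams="
            << avoidRecordStreams << '\n';
  fn();  // launch the actual device work
}

void runAllgatherBase(bool asyncOp) {
  // A wrapper can now decide per call; the heuristic below is illustrative.
  bool avoidRecordStreams = !asyncOp;
  collective([] { /* enqueue kernels */ }, "nccl:_all_gather_base",
             avoidRecordStreams);
}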
7 changes: 5 additions & 2 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -550,7 +550,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       std::vector<at::Tensor>& output,
       Fn fn,
       OpType opType,
-      const char* profilingTitle = nullptr);
+      const char* profilingTitle = nullptr,
+      bool avoidRecordStreams = false);
+
   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
       std::vector<at::Tensor>& input,
@@ -559,7 +561,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       PreProcess pre,
       PostProcess post,
       OpType opType,
-      const char* profilingTitle = nullptr);
+      const char* profilingTitle = nullptr,
+      bool avoidRecordStreams = false);
 
   // Helper that encapsulates work shared across point-to-point communication
   // primitives. It is the same structure as the helper used for collective
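Defaulting the new parameter to `false` keeps every existing call site source-compatible; only callers that care opt in. A tiny, self-contained illustration of that design choice (a hypothetical free function, not the real member signature):

#include <iostream>

void collective(const char* profilingTitle = nullptr,
                bool avoidRecordStreams = false) {
  std::cout << (profilingTitle ? profilingTitle : "(untitled)") << ' '
            << avoidRecordStreams << '\n';
}

int main() {
  collective();                               // old call site, still compiles
  collective("nccl:all_reduce");              // old call site, still compiles
  collective("nccl:_all_gather_base", true);  // new opt-in call site
}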
9 changes: 9 additions & 0 deletions torch/distributed/fsdp/_flat_param.py
@@ -1341,6 +1341,15 @@ def _all_gather_flat_param(
             sharded_flat_param,
             self.process_group,
         )
+
+        if self._offload_params:
+            # In case of offloading, `flat_param.data` (i.e. sharded param) is
+            # created on the pre-unshard stream. We need to hand it over to the
+            # unshard stream for all-gather
+            _no_dispatch_record_stream(
+                sharded_flat_param,
+                self._device_handle.current_stream(),  # unshard_stream
+            )
         return padded_unsharded_flat_param
 
     def _use_unsharded_flat_param(
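In the Python hunk above, `_no_dispatch_record_stream` hands a tensor allocated on one stream over to the stream that will consume it. A hedged C++ analogue using the same caching-allocator primitive seen earlier in this diff (the function and stream names are illustrative): the allocator associates a block with the stream current at allocation time, and recording an extra stream keeps the block safe until that stream's queued work finishes.

#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void handOverBetweenStreams() {
  c10::cuda::CUDAStream producer = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStream consumer = c10::cuda::getStreamFromPool();

  at::Tensor t;
  {
    // Allocate on the producer stream; the caching allocator associates the
    // block with `producer`.
    c10::cuda::CUDAStreamGuard guard(producer);
    t = at::empty({1024}, at::device(at::kCUDA).dtype(at::kFloat));
  }
  // Hand-over: record the consumer stream so the allocator will not recycle
  // the block until work already queued on `consumer` has completed.
  c10::cuda::CUDACachingAllocator::recordStream(
      t.storage().data_ptr(), consumer);
}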
