diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index f8844515e9661f..df8c1bb6d9d613 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1958,7 +1958,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     PreProcess pre,
     PostProcess post,
     OpType opType,
-    const char* profilingTitle) {
+    const char* profilingTitle,
+    bool avoidRecordStreams) {
+  // Environment setting by the user may add onto the collective call's option
+  avoidRecordStreams |= avoidRecordStreams_;
   c10::cuda::CaptureStatus capture_status =
       c10::cuda::currentStreamCaptureStatusMayInitCtx();
   errorIfCapturingNonCapturableNCCL(capture_status);
@@ -2009,7 +2012,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   // Store references to outputs to be used by WorkNCCL::result and operator<<.
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
 
-  if (avoidRecordStreams_) {
+  if (avoidRecordStreams) {
     work->stashed_for_allocator_safety_ =
         std::make_shared<std::vector<at::Tensor>>(inputs);
   }
@@ -2052,7 +2055,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
       // operations where `inputs' and `outputs' are not the same.
       //
       // See [Sync Streams].
-      if (!avoidRecordStreams_) {
+      if (!avoidRecordStreams) {
         if (!inputs[i].is_sparse()) {
           c10::cuda::CUDACachingAllocator::recordStream(
               inputs[i].storage().data_ptr(), ncclStream);
@@ -2111,7 +2114,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
 
   // Set appropriate work parameters.
   work->blockingWait_ = blockingWait_;
-  work->avoidRecordStreams_ = avoidRecordStreams_;
+  work->avoidRecordStreams_ = avoidRecordStreams;
   work->opTimeout_ = options_->timeout;
   work->store_ = store_;
   // Record size info for debug. We only record the size on the first device as
@@ -2322,7 +2325,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     std::vector<at::Tensor>& outputs,
     Fn fn,
     OpType opType,
-    const char* profilingTitle) {
+    const char* profilingTitle,
+    bool avoidRecordStreams) {
   return collective(
       inputs,
       outputs,
@@ -2332,7 +2336,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
       [](std::vector<c10::cuda::CUDAStream>&,
          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
       opType,
-      profilingTitle);
+      profilingTitle,
+      avoidRecordStreams);
 }
 
 template <typename Fn>
@@ -3053,7 +3058,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
             stream.stream());
       },
       OpType::_REDUCE_SCATTER_BASE,
-      "nccl:_reduce_scatter_base");
+      "nccl:_reduce_scatter_base",
+      avoidRecordStreams);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
@@ -3737,7 +3743,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
             stream.stream());
       },
       OpType::_ALLGATHER_BASE,
-      "nccl:_all_gather_base");
+      "nccl:_all_gather_base",
+      avoidRecordStreams);
 }
 
 #ifdef USE_NCCL_WITH_UCC
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 0d468173b7a684..ac69725a4adef8 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -550,7 +550,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       std::vector<at::Tensor>& output,
       Fn fn,
       OpType opType,
-      const char* profilingTitle = nullptr);
+      const char* profilingTitle = nullptr,
+      bool avoidRecordStreams = false);
+
   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
       std::vector<at::Tensor>& input,
@@ -559,7 +561,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       PreProcess pre,
       PostProcess post,
       OpType opType,
-      const char* profilingTitle = nullptr);
+      const char* profilingTitle = nullptr,
+      bool avoidRecordStreams = false);
 
   // Helper that encapsulates work shared across point-to-point communication
   // primitives. It is the same structure as the helper used for collective
diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py
index bf58586b96ae94..14dd7efd18869f 100644
--- a/torch/distributed/fsdp/_flat_param.py
+++ b/torch/distributed/fsdp/_flat_param.py
@@ -1341,6 +1341,15 @@ def _all_gather_flat_param(
             sharded_flat_param,
             self.process_group,
         )
+
+        if self._offload_params:
+            # In the offloading case, `flat_param.data` (i.e. the sharded
+            # param) is created on the pre-unshard stream. We need to hand it
+            # over to the unshard stream for the all-gather.
+            _no_dispatch_record_stream(
+                sharded_flat_param,
+                self._device_handle.current_stream(),  # unshard_stream
+            )
         return padded_unsharded_flat_param
 
     def _use_unsharded_flat_param(
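
For context on what the new `avoidRecordStreams` flag toggles: the patch chooses between stashing input references on the work object versus calling `recordStream` on the caching allocator. Below is a minimal Python sketch of that trade-off, not PyTorch's implementation; the names `StashingWork` and `launch_on_comm_stream` are invented for illustration, with an in-place op standing in for the NCCL kernel on a single GPU.

```python
import torch


class StashingWork:
    """Hypothetical stand-in for WorkNCCL that stashes input references."""

    def __init__(self) -> None:
        self.stashed_for_allocator_safety: list[torch.Tensor] = []

    def wait(self, comm_stream: torch.cuda.Stream) -> None:
        # Once the caller's stream has synced with the comm stream, the
        # stashed references can be dropped and the allocator may reuse
        # the underlying blocks.
        torch.cuda.current_stream().wait_stream(comm_stream)
        self.stashed_for_allocator_safety.clear()


def launch_on_comm_stream(
    inp: torch.Tensor,
    comm_stream: torch.cuda.Stream,
    avoid_record_streams: bool,
) -> StashingWork:
    work = StashingWork()
    if avoid_record_streams:
        # Stash a reference: the caching allocator cannot reuse `inp`'s
        # block while `work` is alive, and no per-tensor CUDA events are
        # recorded.
        work.stashed_for_allocator_safety.append(inp)
    else:
        # recordStream: the allocator tracks `inp`'s use on `comm_stream`
        # and defers block reuse until that stream catches up, which can
        # hold memory longer than strictly necessary.
        inp.record_stream(comm_stream)
    # Order the comm stream after the producing (current) stream.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        inp.mul_(2)  # placeholder for the NCCL collective kernel
    return work


if torch.cuda.is_available():
    s = torch.cuda.Stream()
    t = torch.ones(4, device="cuda")
    w = launch_on_comm_stream(t, s, avoid_record_streams=True)
    w.wait(s)
    torch.cuda.synchronize()
```

The design intuition, as the sketch suggests, is that stashing ties input lifetime to the work object itself, whereas `recordStream` leaves reclamation timing up to the allocator's stream tracking; the latter can inflate peak memory and interacts poorly with CUDA graph capture, which is why the per-call flag (OR-ed with the process group's environment-driven `avoidRecordStreams_`) is threaded through `_reduce_scatter_base` and `_allgather_base` here.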