Commit 5c22a0a

tanquermeta-codesync[bot] authored and committed
Remove completed_queue for NCCL, MCCL & HCCL
Summary: Similar to D85455174, remove completed_queue for the other CCLs.

Context: The original purpose of completed_queue was to prevent the work object from being destroyed outside the main thread. However, now that we have a finalize method, we can rely on the contract that the user holds onto the work object until finalize is called.

Reviewed By: siyengar

Differential Revision: D85771546

fbshipit-source-id: f28e89b8c1fe769f6a144f9684ac4b39f4a44460
1 parent 477b433 commit 5c22a0a
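
The sketch below is not part of the commit; it is a minimal, hypothetical illustration of the ownership contract the summary relies on. SketchWork, SketchWorkQueue, and SketchStatus are simplified stand-ins for TorchWorkNCCL / TorchWorkNCCLQueue, not the real TorchComms API. The point it shows: because the caller keeps its shared_ptr to each work object until finalize() runs on the main thread, a watchdog thread can drop the queue's reference to completed work without ever destroying the object itself, so the old completed_work_queue_ parking area becomes unnecessary.

// Hypothetical sketch (not the actual TorchComms code): illustrates why a
// separate completed-work queue is no longer needed once callers promise to
// keep their work objects alive until finalize().
#include <cassert>
#include <memory>
#include <mutex>
#include <queue>

enum class SketchStatus { COMPLETED, PENDING };

struct SketchWork {
  // The real backend would poll a CUDA event / NCCL error state here.
  SketchStatus checkStatus() const { return SketchStatus::COMPLETED; }
};

class SketchWorkQueue {
 public:
  void enqueueWork(std::shared_ptr<SketchWork> work) {
    std::lock_guard<std::mutex> lock(mutex_);
    pending_.push(std::move(work));
  }

  // Callable from a watchdog thread. Popping a completed entry only releases
  // the queue's reference; the caller still holds its own shared_ptr, so the
  // object is never destroyed on this thread under the stated contract.
  SketchStatus garbageCollect() {
    std::lock_guard<std::mutex> lock(mutex_);
    while (!pending_.empty() &&
           pending_.front()->checkStatus() == SketchStatus::COMPLETED) {
      pending_.pop();
    }
    return pending_.empty() ? SketchStatus::COMPLETED : SketchStatus::PENDING;
  }

  // Main-thread only: drain the queue so it returns holding no references to
  // any work object.
  SketchStatus finalize() {
    std::lock_guard<std::mutex> lock(mutex_);
    SketchStatus status = SketchStatus::COMPLETED;
    while (!pending_.empty()) {
      status = pending_.front()->checkStatus();
      pending_.pop();
    }
    return status;
  }

 private:
  std::queue<std::shared_ptr<SketchWork>> pending_;
  std::mutex mutex_;
};

int main() {
  SketchWorkQueue queue;
  auto work = std::make_shared<SketchWork>();  // caller keeps this reference
  queue.enqueueWork(work);
  queue.garbageCollect();  // e.g. from the timeout watchdog thread
  queue.finalize();        // main thread; only now may `work` be released
  assert(work.use_count() == 1);  // the queue no longer holds a reference
  return 0;
}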

File tree

4 files changed: +8, -17 lines

comms/torchcomms/nccl/TorchCommNCCL.hpp

Lines changed: 1 addition & 1 deletion
@@ -299,7 +299,7 @@ class TorchCommNCCL : public TorchCommBackend,
   void timeoutWatchdog() noexcept;
   void checkInitialized() const;
   void checkAndAbortIfTimedOutOrError();
-  void checkWorkQueue(bool isMainThread);
+  void checkWorkQueue();
   void enqueueWork(std::shared_ptr<TorchWorkNCCL> work, cudaStream_t stream);
   bool getGraphCaptureMode();
   cudaStream_t getOperationStream(bool async_op);

comms/torchcomms/nccl/TorchCommNCCLUtils.cpp

Lines changed: 4 additions & 4 deletions
@@ -175,8 +175,8 @@ TorchCommNCCL::RedOpRAII TorchCommNCCL::getNcclReduceOp(
   }
 }

-void TorchCommNCCL::checkWorkQueue(bool isMainThread) {
-  TorchWorkNCCL::WorkStatus status = workq_.garbageCollect(isMainThread);
+void TorchCommNCCL::checkWorkQueue() {
+  TorchWorkNCCL::WorkStatus status = workq_.garbageCollect();

   switch (status) {
     case TorchWorkNCCL::WorkStatus::TIMEDOUT:
@@ -210,7 +210,7 @@ void TorchCommNCCL::timeoutWatchdog() noexcept {
     }

     // Check work objects for completion or timeout
-    checkWorkQueue(false);
+    checkWorkQueue();
     if (comm_state_ != CommState::NORMAL &&
         options_.abort_process_on_timeout_or_error) {
       // Log the error and abort the process. We cannot abort the NCCL
@@ -243,7 +243,7 @@ void TorchCommNCCL::checkAndAbortIfTimedOutOrError() {
   }

   // First, check work queue status
-  checkWorkQueue(true);
+  checkWorkQueue();

   if (comm_state_ == CommState::TIMEOUT) {
     abortNcclComm();

comms/torchcomms/nccl/TorchWorkNCCL.hpp

Lines changed: 1 addition & 2 deletions
@@ -91,15 +91,14 @@ class TorchWorkNCCLQueue {
   TorchWorkNCCLQueue() = default;
   ~TorchWorkNCCLQueue() = default;

-  TorchWorkNCCL::WorkStatus garbageCollect(bool isMainThread);
+  TorchWorkNCCL::WorkStatus garbageCollect();
   // Finalize function can only be called from the main thread
   TorchWorkNCCL::WorkStatus finalize();
   void enqueueWork(std::shared_ptr<TorchWorkNCCL> work, cudaStream_t stream);

  private:
   std::unordered_map<cudaStream_t, std::queue<std::shared_ptr<TorchWorkNCCL>>>
       stream_work_queues_;
-  std::vector<std::shared_ptr<TorchWorkNCCL>> completed_work_queue_;
   std::recursive_mutex work_queues_mutex_;
 };

comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp

Lines changed: 2 additions & 10 deletions
@@ -5,8 +5,7 @@
 namespace torch {
 namespace comms {

-TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
-    bool isMainThread) {
+TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect() {
   std::lock_guard<std::recursive_mutex> lock(work_queues_mutex_);

   TorchWorkNCCL::WorkStatus last_status = TorchWorkNCCL::WorkStatus::COMPLETED;
@@ -29,7 +28,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
       if (status == TorchWorkNCCL::WorkStatus::COMPLETED) {
         // Work is completed, remove it from the work queue
         work_queue.pop();
-        completed_work_queue_.push_back(work);
         // Continue to the next element in the queue
       } else if (
           status == TorchWorkNCCL::WorkStatus::TIMEDOUT ||
@@ -50,11 +48,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
     }
   }

-  if (isMainThread) {
-    // If we are the main thread, clear the completed work queues
-    completed_work_queue_.clear();
-  }
-
   return last_status;
 }

@@ -70,7 +63,7 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::finalize() {
   // empty
   TorchWorkNCCL::WorkStatus status = TorchWorkNCCL::WorkStatus::COMPLETED;
   while (!stream_work_queues_.empty()) {
-    status = garbageCollect(true);
+    status = garbageCollect();
     if (status == TorchWorkNCCL::WorkStatus::ERROR ||
         status == TorchWorkNCCL::WorkStatus::TIMEDOUT ||
        status == TorchWorkNCCL::WorkStatus::COMPLETED) {
@@ -83,7 +76,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::finalize() {
   // NOTE: finalize MUST return without holding references to any work object,
   // otherwise it may leak object and cause side effects.
   stream_work_queues_.clear();
-  completed_work_queue_.clear();

   return status;
 }
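
As a reading aid, here is a rough, hypothetical reconstruction of the shape garbageCollect() takes after this commit, with the completed_work_queue_ hand-off gone. StreamHandle, Work, and WorkQueueSketch are simplified stand-ins (the real code uses cudaStream_t and TorchWorkNCCL, and may differ beyond what the hunks above show).

// Hypothetical, simplified reconstruction of the post-change garbage
// collection loop; stand-in types, not the real TorchComms/CUDA types.
#include <memory>
#include <mutex>
#include <queue>
#include <unordered_map>

enum class Status { COMPLETED, TIMEDOUT, ERROR, PENDING };

using StreamHandle = int;  // stand-in for cudaStream_t

struct Work {
  Status state = Status::COMPLETED;
  Status checkStatus() const { return state; }
};

class WorkQueueSketch {
 public:
  void enqueueWork(std::shared_ptr<Work> work, StreamHandle stream) {
    std::lock_guard<std::recursive_mutex> lock(mutex_);
    queues_[stream].push(std::move(work));
  }

  // Walk every per-stream queue, dropping completed work and stopping each
  // stream at its first unfinished entry. Completed entries are simply
  // popped: there is no completed_work_queue_ to park them in anymore.
  Status garbageCollect() {
    std::lock_guard<std::recursive_mutex> lock(mutex_);
    Status last_status = Status::COMPLETED;

    for (auto it = queues_.begin(); it != queues_.end();) {
      auto& work_queue = it->second;
      while (!work_queue.empty()) {
        Status s = work_queue.front()->checkStatus();
        if (s == Status::COMPLETED) {
          work_queue.pop();  // queue drops its reference; caller keeps its own
        } else if (s == Status::TIMEDOUT || s == Status::ERROR) {
          last_status = s;  // surface the failure to the caller
          break;
        } else {
          break;  // first pending entry blocks the rest of this stream
        }
      }
      if (work_queue.empty()) {
        it = queues_.erase(it);  // forget streams that have fully drained
      } else {
        ++it;
      }
    }
    return last_status;
  }

 private:
  std::unordered_map<StreamHandle, std::queue<std::shared_ptr<Work>>> queues_;
  std::recursive_mutex mutex_;
};

The premise, per the commit summary, is that popping a completed entry never triggers destruction on the watchdog thread, because the caller is contractually still holding its own reference until finalize() has run.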
