2 changes: 1 addition & 1 deletion comms/torchcomms/nccl/TorchCommNCCL.hpp
@@ -299,7 +299,7 @@ class TorchCommNCCL : public TorchCommBackend,
   void timeoutWatchdog() noexcept;
   void checkInitialized() const;
   void checkAndAbortIfTimedOutOrError();
-  void checkWorkQueue(bool isMainThread);
+  void checkWorkQueue();
   void enqueueWork(std::shared_ptr<TorchWorkNCCL> work, cudaStream_t stream);
   bool getGraphCaptureMode();
   cudaStream_t getOperationStream(bool async_op);
8 changes: 4 additions & 4 deletions comms/torchcomms/nccl/TorchCommNCCLUtils.cpp
@@ -175,8 +175,8 @@ TorchCommNCCL::RedOpRAII TorchCommNCCL::getNcclReduceOp(
   }
 }
 
-void TorchCommNCCL::checkWorkQueue(bool isMainThread) {
-  TorchWorkNCCL::WorkStatus status = workq_.garbageCollect(isMainThread);
+void TorchCommNCCL::checkWorkQueue() {
+  TorchWorkNCCL::WorkStatus status = workq_.garbageCollect();
 
   switch (status) {
     case TorchWorkNCCL::WorkStatus::TIMEDOUT:
@@ -210,7 +210,7 @@ void TorchCommNCCL::timeoutWatchdog() noexcept {
   }
 
   // Check work objects for completion or timeout
-  checkWorkQueue(false);
+  checkWorkQueue();
   if (comm_state_ != CommState::NORMAL &&
       options_.abort_process_on_timeout_or_error) {
     // Log the error and abort the process. We cannot abort the NCCL
@@ -243,7 +243,7 @@ void TorchCommNCCL::checkAndAbortIfTimedOutOrError() {
   }
 
   // First, check work queue status
-  checkWorkQueue(true);
+  checkWorkQueue();
 
   if (comm_state_ == CommState::TIMEOUT) {
     abortNcclComm();
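Note: with the isMainThread flag gone, the watchdog thread (timeoutWatchdog) and the main thread (checkAndAbortIfTimedOutOrError) now take the identical checkWorkQueue() path, so concurrent calls are serialized only by the queue's internal mutex rather than by main-thread-only cleanup. Below is a minimal, self-contained C++ sketch of that pattern; the names and the trivial integer work items are illustrative stand-ins, not the torchcomms implementation.

// Illustrative sketch only (assumed names, not torchcomms code): both the
// watchdog thread and the main thread call the same queue-checking routine,
// and the queue's own mutex is what makes the concurrent calls safe.
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>

namespace {
std::recursive_mutex work_queues_mutex;  // plays the role of work_queues_mutex_
std::queue<int> work_queue;              // stand-in for the per-stream work queues

void check_work_queue() {                // single shared path for both threads
  std::lock_guard<std::recursive_mutex> lock(work_queues_mutex);
  while (!work_queue.empty()) {
    work_queue.pop();                    // completed work is released in place
  }
}
}  // namespace

int main() {
  for (int i = 0; i < 8; ++i) {
    work_queue.push(i);
  }
  std::thread watchdog(check_work_queue);  // background watchdog path
  check_work_queue();                      // main-thread path, same function
  watchdog.join();
  std::printf("items left: %zu\n", work_queue.size());
  return 0;
}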
3 changes: 1 addition & 2 deletions comms/torchcomms/nccl/TorchWorkNCCL.hpp
@@ -91,15 +91,14 @@ class TorchWorkNCCLQueue {
   TorchWorkNCCLQueue() = default;
   ~TorchWorkNCCLQueue() = default;
 
-  TorchWorkNCCL::WorkStatus garbageCollect(bool isMainThread);
+  TorchWorkNCCL::WorkStatus garbageCollect();
   // Finalize function can only be called from the main thread
   TorchWorkNCCL::WorkStatus finalize();
   void enqueueWork(std::shared_ptr<TorchWorkNCCL> work, cudaStream_t stream);
 
  private:
   std::unordered_map<cudaStream_t, std::queue<std::shared_ptr<TorchWorkNCCL>>>
       stream_work_queues_;
-  std::vector<std::shared_ptr<TorchWorkNCCL>> completed_work_queue_;
   std::recursive_mutex work_queues_mutex_;
 };

12 changes: 2 additions & 10 deletions comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp
@@ -5,8 +5,7 @@
 namespace torch {
 namespace comms {
 
-TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
-    bool isMainThread) {
+TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect() {
   std::lock_guard<std::recursive_mutex> lock(work_queues_mutex_);
 
   TorchWorkNCCL::WorkStatus last_status = TorchWorkNCCL::WorkStatus::COMPLETED;
@@ -29,7 +28,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
       if (status == TorchWorkNCCL::WorkStatus::COMPLETED) {
         // Work is completed, remove it from the work queue
         work_queue.pop();
-        completed_work_queue_.push_back(work);
         // Continue to the next element in the queue
       } else if (
           status == TorchWorkNCCL::WorkStatus::TIMEDOUT ||
@@ -50,11 +48,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::garbageCollect(
     }
   }
 
-  if (isMainThread) {
-    // If we are the main thread, clear the completed work queues
-    completed_work_queue_.clear();
-  }
-
   return last_status;
 }
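For context, here is a minimal sketch of what the simplified garbage-collection loop looks like once completed work is destroyed in place instead of being parked in completed_work_queue_ for the main thread to clear later. Status, Work, StreamId, and WorkQueues are assumed stand-ins for illustration; the real TorchWorkNCCLQueue operates on TorchWorkNCCL objects keyed by cudaStream_t, and its status checks query the underlying CUDA work.

// Sketch under assumed types: pop completed work immediately (the shared_ptr
// releases it right here), stop at the first unfinished item in each stream,
// and surface the last non-completed status to the caller.
#include <map>
#include <memory>
#include <queue>

enum class Status { COMPLETED, PENDING, TIMEDOUT, ERROR };

struct Work {
  Status state = Status::COMPLETED;
  Status status() const { return state; }  // real code would poll a CUDA event
};

using StreamId = int;  // stand-in for cudaStream_t
using WorkQueues = std::map<StreamId, std::queue<std::shared_ptr<Work>>>;

Status garbageCollect(WorkQueues& queues) {
  Status last = Status::COMPLETED;
  for (auto it = queues.begin(); it != queues.end();) {
    auto& work_queue = it->second;
    while (!work_queue.empty()) {
      Status s = work_queue.front()->status();
      if (s == Status::COMPLETED) {
        work_queue.pop();  // last reference dropped here; no deferred cleanup
      } else {
        last = s;  // PENDING, TIMEDOUT, or ERROR blocks this stream
        break;
      }
    }
    it = work_queue.empty() ? queues.erase(it) : ++it;
  }
  return last;
}

Destroying the work object at the pop site is what removes the need for an isMainThread distinction: there is no longer a second container whose lifetime only the main thread was allowed to manage.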

@@ -70,7 +63,7 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::finalize() {
   // empty
   TorchWorkNCCL::WorkStatus status = TorchWorkNCCL::WorkStatus::COMPLETED;
   while (!stream_work_queues_.empty()) {
-    status = garbageCollect(true);
+    status = garbageCollect();
     if (status == TorchWorkNCCL::WorkStatus::ERROR ||
         status == TorchWorkNCCL::WorkStatus::TIMEDOUT ||
         status == TorchWorkNCCL::WorkStatus::COMPLETED) {
@@ -83,7 +76,6 @@ TorchWorkNCCL::WorkStatus TorchWorkNCCLQueue::finalize() {
   // NOTE: finalize MUST return without holding references to any work object,
   // otherwise it may leak object and cause side effects.
   stream_work_queues_.clear();
-  completed_work_queue_.clear();
 
   return status;
 }
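A matching sketch of the finalize() drain pattern, reusing the types from the sketch above: keep garbage-collecting until the queues are empty or an error/timeout surfaces, then clear the queues so no work references outlive the call. Again, the names and the simple polling loop are assumed for illustration, not the actual implementation.

// Sketch continuing the previous one: drain all per-stream queues.
Status finalizeQueues(WorkQueues& queues) {
  Status status = Status::COMPLETED;
  while (!queues.empty()) {
    status = garbageCollect(queues);
    if (status == Status::ERROR || status == Status::TIMEDOUT ||
        status == Status::COMPLETED) {
      break;  // mirrors the break condition visible in the hunk above
    }
    // Otherwise work is still pending; the real code keeps checking until the
    // outstanding work completes or times out.
  }
  queues.clear();  // finalize must not hold references to any work object
  return status;
}

int main() {
  WorkQueues queues;
  queues[0].push(std::make_shared<Work>());  // one already-completed item
  return finalizeQueues(queues) == Status::COMPLETED ? 0 : 1;
}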