[c10d] Add flag value for direct teardown without comm abort (pytorch#102599)

It was recently reported that `ncclCommAbort` itself may hang with some NCCL versions (see, for example, NVIDIA/nccl#829).
In that case, it may be desirable to tear the program down directly, without properly aborting the NCCL communicator, so that the user does not wait for hours before noticing the hang.
This PR adds a new value, 3, for the env var `NCCL_ASYNC_ERROR_HANDLING` that skips the comm abort and directly throws the error in case of an exception (timeout, async error, etc.).
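
For illustration, here is a minimal sketch of how the env value maps onto the `ErrorHandlingMode` enum introduced below. The parsing helper is hypothetical; the actual env lookup lives elsewhere in `ProcessGroupNCCL` and is not part of this diff.

```cpp
#include <cstdlib>

// Mirrors the enum added to ProcessGroupNCCL.hpp in this PR.
enum ErrorHandlingMode { NoHandling = 0, TearDown = 1, CleanUpOnly = 2, SkipCleanUp = 3 };

// Hypothetical helper: translate NCCL_ASYNC_ERROR_HANDLING into a mode.
ErrorHandlingMode parseAsyncErrorHandling() {
  const char* val = std::getenv("NCCL_ASYNC_ERROR_HANDLING");
  switch (val ? std::atoi(val) : 0) {
    case 1:
      return TearDown; // abort communicators, then tear down the process
    case 2:
      return CleanUpOnly; // abort communicators but keep the process alive
    case 3:
      return SkipCleanUp; // new in this PR: tear down without ncclCommAbort
    default:
      return NoHandling; // unset or 0: async errors are not handled
  }
}
```

With `NCCL_ASYNC_ERROR_HANDLING=3` exported before launch, a watchdog-detected timeout or async error therefore leads straight to teardown, bypassing the `ncclCommAbort` call that may itself hang.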
Pull Request resolved: pytorch#102599
Approved by: https://github.com/fegin
kwen2501 authored and pytorchmergebot committed Jun 2, 2023
1 parent 5be1088 commit 9fbfaaa
Showing 2 changed files with 19 additions and 9 deletions.
14 changes: 8 additions & 6 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -470,7 +470,7 @@ void ProcessGroupNCCL::WorkNCCL::handleException(
   LOG(ERROR) << exceptionMsg;
   C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException");
 
-  if (errorHandling == TearDown) {
+  if (SHOULD_TEAR_DOWN(errorHandling)) {
     auto tearDownMsg = c10::str(
         "To avoid data inconsistency, we are taking the entire process down.");
     LOG(ERROR) << tearDownMsg;
@@ -884,11 +884,13 @@ void ProcessGroupNCCL::workCleanupLoop() {
 
       // If work hits an exception (either an error or timeout)
       if (work.exception()) {
-        // Abort work and corresponding communicators
-        work.abort();
-        // PG level abort, which would abort all other communicators on this
-        // rank
-        abort();
+        if (SHOULD_CLEAN_UP(asyncErrorHandling_)) {
+          // Abort work and corresponding communicators
+          work.abort();
+          // PG level abort, which would abort all other communicators on this
+          // rank
+          abort();
+        }
         // Report desync state in case of timeout
         if (desyncDebug_ && timedOut) {
           try {
14 changes: 11 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -44,9 +44,17 @@ constexpr const char* NCCL_DESYNC_DEBUG = "NCCL_DESYNC_DEBUG";
 
 constexpr const char* NCCL_BACKEND_NAME = "nccl";
 
-// TearDown mode: tear down process upon error, see `WorkNCCL::handleException`
-// Soft mode: just clean up collectives and abort communicators without tearing down process
-enum ErrorHandlingMode { NoHandling = 0, TearDown = 1, CleanUpOnly = 2 };
+// NoHandling: do not handle asynchronous NCCL errors
+// TearDown: tear down process upon error, see `WorkNCCL::handleException`
+// CleanUpOnly: just clean up collectives and abort communicators without tearing down process
+// SkipCleanUp: (this is a temporary option and can be removed in future) tear
+// down process without cleaning up NCCL communicators. This should be used as a
+// last resort in case `ncclCommAbort` itself is hanging
+enum ErrorHandlingMode { NoHandling = 0, TearDown = 1, CleanUpOnly = 2, SkipCleanUp = 3 };
+
+#define SHOULD_CLEAN_UP(a) (a != NoHandling && a != SkipCleanUp)
+
+#define SHOULD_TEAR_DOWN(a) (a != NoHandling && a != CleanUpOnly)
 
 // If set, ProcessGroupNCCL doesn't use recordStream calls to ensure
 // caching allocator safety for tensors used on both user-facing and
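
To make the behavior of the two new macros concrete, here is a small, self-contained sketch (illustration only, not code from this PR) that prints, for each mode, whether communicators get aborted and whether the process gets torn down:

```cpp
#include <cstdio>

enum ErrorHandlingMode { NoHandling = 0, TearDown = 1, CleanUpOnly = 2, SkipCleanUp = 3 };

// Same predicates as the macros added to ProcessGroupNCCL.hpp above.
#define SHOULD_CLEAN_UP(a) (a != NoHandling && a != SkipCleanUp)
#define SHOULD_TEAR_DOWN(a) (a != NoHandling && a != CleanUpOnly)

int main() {
  const char* names[] = {"NoHandling", "TearDown", "CleanUpOnly", "SkipCleanUp"};
  for (int m = NoHandling; m <= SkipCleanUp; ++m) {
    std::printf(
        "%-12s abort communicators: %-3s tear down process: %s\n",
        names[m],
        SHOULD_CLEAN_UP(m) ? "yes" : "no",
        SHOULD_TEAR_DOWN(m) ? "yes" : "no");
  }
  return 0;
}
```

Note that `SkipCleanUp` still tears the process down via `WorkNCCL::handleException`; it only skips the `work.abort()` and process-group-level `abort()` calls in `workCleanupLoop`, which are the ones that can hang inside `ncclCommAbort`.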
