Skip to content

Commit 4a45d29

Browse files
committed
Merge branch 'xccl-bak' into xccl-group
2 parents 4f4ecf4 + e621fe6 commit 4a45d29

File tree

9 files changed

+130
-143
lines changed

9 files changed

+130
-143
lines changed

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,6 @@ cmake_dependent_option(
369369
USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF)
370370
cmake_dependent_option(
371371
USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF)
372-
cmake_dependent_option(
373-
USE_C10D_XCCL "USE C10D XCCL" ON "USE_DISTRIBUTED;USE_XCCL" OFF)
374372
cmake_dependent_option(
375373
USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF)
376374
cmake_dependent_option(

caffe2/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,10 @@ elseif(USE_CUDA)
10131013
endif()
10141014

10151015
if(USE_XPU)
1016+
# if SYCL runtime and oneCCL runtime are both system installed
1017+
# then the build flags are USE_XPU=ON, USE_XCCL=ON and USE_C10D_XCCL=ON;
1018+
# the XCCL backend will be built in libtorch_xpu;
1019+
# manually set `USE_XCCL=OFF` to disable building the XCCL backend.
10161020
if(USE_XCCL)
10171021
append_filelist("libtorch_xpu_distributed_extra_sources" Caffe2_XPU_SRCS)
10181022
endif()
@@ -1370,7 +1374,7 @@ if(USE_DISTRIBUTED)
13701374
target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL)
13711375
endif()
13721376
endif()
1373-
if(USE_C10D_XCCL)
1377+
if(USE_XPU AND USE_C10D_XCCL)
13741378
target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
13751379
set_source_files_properties(
13761380
${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp

cmake/Dependencies.cmake

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,9 @@ if(USE_XCCL)
11631163
caffe2_update_option(USE_XCCL OFF)
11641164
else()
11651165
include(${CMAKE_CURRENT_LIST_DIR}/External/xccl.cmake)
1166-
list(APPEND Caffe2_XPU_DEPENDENCY_LIBS torch::xccl)
1166+
if(NOT XCCL_FOUND)
1167+
caffe2_update_option(USE_XCCL OFF)
1168+
endif()
11671169
endif()
11681170
endif()
11691171

cmake/External/xccl.cmake

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
if(NOT __XCCL_INCLUDED)
22
set(__XCCL_INCLUDED TRUE)
33

4-
if(USE_XCCL)
5-
# XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake.
6-
find_package(XCCL REQUIRED)
7-
if(XCCL_FOUND)
8-
add_library(torch::xccl INTERFACE IMPORTED)
9-
set_property(
10-
TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES
11-
${XCCL_INCLUDE_DIR})
12-
set_property(
13-
TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES
14-
${XCCL_LIBRARY})
15-
endif()
4+
# XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake.
5+
find_package(XCCL REQUIRED)
6+
if(XCCL_FOUND)
7+
add_library(torch::xccl INTERFACE IMPORTED)
8+
set_property(
9+
TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES
10+
${XCCL_INCLUDE_DIR})
11+
set_property(
12+
TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES
13+
${XCCL_LIBRARY})
1614
endif()
1715
endif()

cmake/Modules/FindXCCL.cmake

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ if(DEFINED ENV{CCL_ROOT})
1111
set(XCCL_ROOT $ENV{CCL_ROOT})
1212
endif()
1313

14-
string(COMPARE EQUAL "${XCCL_ROOT}" "" nosyclfound)
15-
if(nosyclfound)
14+
string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound)
15+
if(nocclfound)
1616
set(XCCL_FOUND False)
17-
set(XCCL_REASON_FAILURE "XCCL library not set!!")
17+
set(XCCL_REASON_FAILURE "OneCCL library not found!!")
1818
set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}")
1919
return()
2020
endif()
@@ -55,7 +55,7 @@ find_library(
5555

5656
if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY))
5757
set(XCCL_FOUND False)
58-
set(XCCL_REASON_FAILURE "XCCL library is incomplete!!")
58+
set(XCCL_REASON_FAILURE "OneCCL library not found!!")
5959
set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}")
6060
return()
6161
endif()

test/distributed/test_c10d_common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,6 +1836,9 @@ def test_init_process_group_for_all_backends(self):
18361836
elif backend == dist.Backend.UCC:
18371837
if not dist.is_ucc_available():
18381838
continue
1839+
elif backend == dist.Backend.XCCL:
1840+
if not dist.is_xccl_available():
1841+
continue
18391842
# Multi-threaded PG is defined as a pure python class.
18401843
# Its pg.name() does not go through Pybind, so its backend name
18411844
# is still "threaded" instead of "custom".

torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp

Lines changed: 18 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,19 @@
1+
#ifdef USE_C10D_XCCL
2+
13
#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
24
#include <fstream>
3-
#include <mutex>
4-
#include <sstream>
5-
6-
#ifdef USE_C10D_XCCL
75
#include <comm/XPUGuard.h>
86
#include <exception>
97
#include <map>
8+
#include <sstream>
109
#include <stdexcept>
1110
#include <tuple>
1211
#include <unordered_set>
1312
#include <utility>
1413

1514
#include <ATen/detail/FunctionTraits.h>
1615
#include <c10/core/DeviceType.h>
17-
#include <c10/util/CallOnce.h>
18-
#include <c10/util/Exception.h>
19-
#include <c10/util/Logging.h>
2016
#include <c10/util/Optional.h>
21-
#include <c10/util/irange.h>
22-
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
23-
#include <torch/csrc/distributed/c10d/TraceUtils.h>
24-
#include <torch/csrc/distributed/c10d/Utils.hpp>
25-
#include <torch/torch.h>
2617

2718
namespace c10d {
2819

@@ -61,36 +52,6 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
6152
{at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
6253
};
6354

64-
XCCL_KVS kvs;
65-
std::mutex kvs_mutex;
66-
67-
XCCL_KVS get_kvs(int rank, c10d::Store& store) {
68-
std::lock_guard<std::mutex> lock(kvs_mutex);
69-
if (kvs)
70-
return kvs;
71-
std::string storeKey = "xccl_kvs";
72-
73-
// Rank 0 broadcasts the bootstrap network information to other ranks
74-
if (rank == 0) {
75-
kvs = ccl::create_main_kvs();
76-
ccl::kvs::address_type main_addr = kvs->get_address();
77-
auto ccl_kvs_addr =
78-
std::vector<uint8_t>(main_addr.begin(), main_addr.end());
79-
store.set(storeKey, ccl_kvs_addr);
80-
} else {
81-
auto ccl_kvs_addr = store.get(storeKey);
82-
if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
83-
throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
84-
}
85-
ccl::kvs::address_type main_addr;
86-
std::copy_n(
87-
ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
88-
kvs = ccl::create_kvs(main_addr);
89-
}
90-
91-
return kvs;
92-
}
93-
9455
bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
9556
for (const auto& input_tensor : input_tensors) {
9657
if (!input_tensors[0].is_same_size(input_tensor)) {
@@ -159,23 +120,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
159120
}
160121
return xcclOps.at(reduceOp);
161122
} catch (const std::out_of_range&) {
162-
switch (reduceOp) {
163-
case ReduceOp::AVG:
164-
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp AVG with XCCL");
165-
break;
166-
case ReduceOp::BAND:
167-
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with XCCL");
168-
break;
169-
case ReduceOp::BOR:
170-
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with XCCL");
171-
break;
172-
case ReduceOp::BXOR:
173-
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with XCCL");
174-
break;
175-
default:
176-
C10_THROW_ERROR(ValueError, "Unhandled ReduceOp");
177-
break;
178-
}
123+
C10_THROW_ERROR(
124+
ValueError,
125+
"Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL");
179126
}
180127
}
181128

@@ -210,20 +157,6 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
210157

211158
ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;
212159

213-
bool ProcessGroupXCCL::WorkXCCL::checkTimeout(
214-
std::optional<std::chrono::milliseconds> timeout) {
215-
auto currentTimepoint = std::chrono::steady_clock::now();
216-
auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
217-
currentTimepoint - workStartTime_);
218-
std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000);
219-
220-
auto workTimeout = timeout ? *timeout : opTimeout;
221-
222-
if (timeElapsed < workTimeout)
223-
return false;
224-
return true;
225-
}
226-
227160
bool ProcessGroupXCCL::WorkXCCL::isCompleted() {
228161
if (xcclEndEvent_ && xcclEndEvent_->query()) {
229162
return true;
@@ -235,23 +168,23 @@ void ProcessGroupXCCL::WorkXCCL::synchronize() {
235168
synchronizeInternal(kNoTimeout);
236169
}
237170

238-
void ProcessGroupXCCL::WorkXCCL::synchronizeStream() {
239-
auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
240-
// Block the current stream on the XCCL stream
241-
xcclEndEvent_->block(currentStream);
242-
}
243-
244171
void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
245172
std::chrono::milliseconds timeout) {
246-
synchronizeStream();
247-
173+
auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
174+
xcclEndEvent_->block(currentStream);
248175
if (blockingWait_) {
249176
while (!isCompleted()) {
250-
bool timedOut = checkTimeout(
251-
timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout));
252-
if (timedOut) {
253-
break;
177+
auto currentTimepoint = std::chrono::steady_clock::now();
178+
auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
179+
currentTimepoint - workStartTime_);
180+
if (timeElapsed >= timeout) {
181+
std::string exceptionMsg = c10::str(
182+
"Work ran for ",
183+
timeElapsed.count(),
184+
" milliseconds before timing out.");
185+
TORCH_CHECK(false, exceptionMsg)
254186
}
187+
255188
std::this_thread::sleep_for(
256189
std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
257190
}

torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp

Lines changed: 62 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,42 +28,9 @@
2828
#include <c10/xpu/XPUCachingAllocator.h>
2929
#include <torch/csrc/distributed/c10d/Backend.hpp>
3030
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
31-
#include <torch/csrc/distributed/c10d/Store.hpp>
3231
namespace c10d {
3332

3433
namespace {
35-
int getXCCLEnvVar(std::string envVarName) {
36-
char* stringValue = std::getenv(envVarName.c_str());
37-
if (stringValue != nullptr) {
38-
try {
39-
int val = std::stoi(stringValue);
40-
return val;
41-
} catch (std::exception& e) {
42-
TORCH_CHECK(
43-
false,
44-
"Invalid value for environment variable: " + std::string(envVarName));
45-
}
46-
} else {
47-
return -1;
48-
}
49-
}
50-
51-
template <typename T>
52-
void setXCCLEnvVar(const std::string& envVarName, T val) {
53-
if constexpr (std::is_same_v<T, int>) {
54-
setenv(envVarName.c_str(), std::to_string(val).c_str(), 1);
55-
} else if constexpr (std::is_same_v<T, std::string>) {
56-
setenv(envVarName.c_str(), val.c_str(), 1);
57-
}
58-
}
59-
60-
bool with_mpirun() {
61-
return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") ||
62-
getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK"))
63-
? true
64-
: false;
65-
}
66-
6734
struct AutoXcclGroup {
6835
AutoXcclGroup();
6936
~AutoXcclGroup() noexcept(false);
@@ -103,8 +70,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
10370

10471
void synchronize() override;
10572

106-
void synchronizeStream();
107-
10873
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
10974

11075
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
@@ -115,9 +80,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
11580
return *outputs_;
11681
}
11782

118-
bool checkTimeout(
119-
std::optional<std::chrono::milliseconds> timeout = std::nullopt);
120-
12183
protected:
12284
at::Device device_;
12385
std::shared_ptr<at::xpu::XPUEvent> xcclEndEvent_;
@@ -330,7 +292,69 @@ class TORCH_API ProcessGroupXCCL : public Backend {
330292
std::shared_ptr<xcclComm_t> coalescedComm_ = nullptr;
331293
bool blockingWait_ = false;
332294
static thread_local uint64_t xcclActiveGroupCounter_;
295+
private:
296+
XCCL_KVS kvs;
297+
std::mutex kvs_mutex;
298+
XCCL_KVS get_kvs(int rank, c10d::Store& store) {
299+
std::lock_guard<std::mutex> lock(kvs_mutex);
300+
if (kvs)
301+
return kvs;
302+
std::string storeKey = "xccl_kvs";
303+
// Rank 0 broadcasts the bootstrap network information to other ranks
304+
if (rank == 0) {
305+
kvs = ccl::create_main_kvs();
306+
ccl::kvs::address_type main_addr = kvs->get_address();
307+
auto ccl_kvs_addr =
308+
std::vector<uint8_t>(main_addr.begin(), main_addr.end());
309+
store.set(storeKey, ccl_kvs_addr);
310+
} else {
311+
auto ccl_kvs_addr = store.get(storeKey);
312+
if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
313+
throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
314+
}
315+
ccl::kvs::address_type main_addr;
316+
std::copy_n(
317+
ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
318+
kvs = ccl::create_kvs(main_addr);
319+
}
320+
return kvs;
321+
}
333322
};
323+
324+
namespace {
325+
int getXCCLEnvVar(std::string envVarName) {
326+
char* stringValue = std::getenv(envVarName.c_str());
327+
if (stringValue != nullptr) {
328+
try {
329+
int val = std::stoi(stringValue);
330+
return val;
331+
} catch (std::exception& e) {
332+
TORCH_CHECK(
333+
false,
334+
"Invalid value for environment variable: " + std::string(envVarName));
335+
}
336+
} else {
337+
return -1;
338+
}
339+
}
340+
341+
template <typename T>
342+
void setXCCLEnvVar(const std::string& envVarName, T val) {
343+
if constexpr (std::is_same_v<T, int>) {
344+
setenv(envVarName.c_str(), std::to_string(val).c_str(), 1);
345+
} else if constexpr (std::is_same_v<T, std::string>) {
346+
setenv(envVarName.c_str(), val.c_str(), 1);
347+
}
348+
}
349+
350+
bool with_mpirun() {
351+
return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") ||
352+
getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK"))
353+
? true
354+
: false;
355+
}
356+
357+
} // namespace
334358
} // namespace c10d
335359

336360
#endif // USE_C10D_XCCL

0 commit comments

Comments
 (0)