1+ #ifdef  USE_C10D_XCCL
2+ 
13#include  < torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp> 
24#include  < fstream> 
3- #include  < mutex> 
4- #include  < sstream> 
5- 
6- #ifdef  USE_C10D_XCCL
75#include  < comm/XPUGuard.h> 
86#include  < exception> 
97#include  < map> 
8+ #include  < sstream> 
109#include  < stdexcept> 
1110#include  < tuple> 
1211#include  < unordered_set> 
1312#include  < utility> 
1413
1514#include  < ATen/detail/FunctionTraits.h> 
1615#include  < c10/core/DeviceType.h> 
17- #include  < c10/util/CallOnce.h> 
18- #include  < c10/util/Exception.h> 
19- #include  < c10/util/Logging.h> 
2016#include  < c10/util/Optional.h> 
21- #include  < c10/util/irange.h> 
22- #include  < torch/csrc/distributed/c10d/ParamCommsUtils.hpp> 
23- #include  < torch/csrc/distributed/c10d/TraceUtils.h> 
24- #include  < torch/csrc/distributed/c10d/Utils.hpp> 
25- #include  < torch/torch.h> 
2617
2718namespace  c10d  {
2819
@@ -61,36 +52,6 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
6152    {at::kFloat8_e5m2fnuz , ccl::datatype::uint8},
6253};
6354
64- XCCL_KVS kvs;
65- std::mutex kvs_mutex;
66- 
67- XCCL_KVS get_kvs (int  rank, c10d::Store& store) {
68-   std::lock_guard<std::mutex> lock (kvs_mutex);
69-   if  (kvs)
70-     return  kvs;
71-   std::string storeKey = " xccl_kvs" 
72- 
73-   //  Rank 0 broadcast the bootstrap network information to other ranks
74-   if  (rank == 0 ) {
75-     kvs = ccl::create_main_kvs ();
76-     ccl::kvs::address_type main_addr = kvs->get_address ();
77-     auto  ccl_kvs_addr =
78-         std::vector<uint8_t >(main_addr.begin (), main_addr.end ());
79-     store.set (storeKey, ccl_kvs_addr);
80-   } else  {
81-     auto  ccl_kvs_addr = store.get (storeKey);
82-     if  (ccl_kvs_addr.size () != ccl::kvs::address_max_size) {
83-       throw  std::runtime_error (" Unexpected ccl kvs addr from the store\n " 
84-     }
85-     ccl::kvs::address_type main_addr;
86-     std::copy_n (
87-         ccl_kvs_addr.begin (), ccl::kvs::address_max_size, main_addr.begin ());
88-     kvs = ccl::create_kvs (main_addr);
89-   }
90- 
91-   return  kvs;
92- }
93- 
9455bool  check_same_size (const  std::vector<at::Tensor>& input_tensors) {
9556  for  (const  auto & input_tensor : input_tensors) {
9657    if  (!input_tensors[0 ].is_same_size (input_tensor)) {
@@ -159,23 +120,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
159120    }
160121    return  xcclOps.at (reduceOp);
161122  } catch  (const  std::out_of_range&) {
162-     switch  (reduceOp) {
163-       case  ReduceOp::AVG:
164-         C10_THROW_ERROR (ValueError, " Cannot use ReduceOp AVG with XCCL" 
165-         break ;
166-       case  ReduceOp::BAND:
167-         C10_THROW_ERROR (ValueError, " Cannot use ReduceOp.BAND with XCCL" 
168-         break ;
169-       case  ReduceOp::BOR:
170-         C10_THROW_ERROR (ValueError, " Cannot use ReduceOp.BOR with XCCL" 
171-         break ;
172-       case  ReduceOp::BXOR:
173-         C10_THROW_ERROR (ValueError, " Cannot use ReduceOp.BXOR with XCCL" 
174-         break ;
175-       default :
176-         C10_THROW_ERROR (ValueError, " Unhandled ReduceOp" 
177-         break ;
178-     }
123+     C10_THROW_ERROR (
124+         ValueError,
125+         " Cannot use ReduceOp." reduce_op_to_string (reduceOp) + "  with XCCL" 
179126  }
180127}
181128
@@ -210,20 +157,6 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
210157
211158ProcessGroupXCCL::WorkXCCL::~WorkXCCL () = default ;
212159
213- bool  ProcessGroupXCCL::WorkXCCL::checkTimeout (
214-     std::optional<std::chrono::milliseconds> timeout) {
215-   auto  currentTimepoint = std::chrono::steady_clock::now ();
216-   auto  timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
217-       currentTimepoint - workStartTime_);
218-   std::chrono::milliseconds opTimeout = std::chrono::milliseconds (60000 );
219- 
220-   auto  workTimeout = timeout ? *timeout : opTimeout;
221- 
222-   if  (timeElapsed < workTimeout)
223-     return  false ;
224-   return  true ;
225- }
226- 
227160bool  ProcessGroupXCCL::WorkXCCL::isCompleted () {
228161  if  (xcclEndEvent_ && xcclEndEvent_->query ()) {
229162    return  true ;
@@ -235,23 +168,23 @@ void ProcessGroupXCCL::WorkXCCL::synchronize() {
235168  synchronizeInternal (kNoTimeout );
236169}
237170
238- void  ProcessGroupXCCL::WorkXCCL::synchronizeStream () {
239-   auto  currentStream = at::xpu::getCurrentXPUStream (device_.index ());
240-   //  Block the current stream on the XCCL stream
241-   xcclEndEvent_->block (currentStream);
242- }
243- 
244171void  ProcessGroupXCCL::WorkXCCL::synchronizeInternal (
245172    std::chrono::milliseconds timeout) {
246-   synchronizeStream ( );
247- 
173+   auto  currentStream =  at::xpu::getCurrentXPUStream (device_. index () );
174+   xcclEndEvent_-> block (currentStream); 
248175  if  (blockingWait_) {
249176    while  (!isCompleted ()) {
250-       bool  timedOut = checkTimeout (
251-           timeout == kNoTimeout  ? std::nullopt  : std::make_optional (timeout));
252-       if  (timedOut) {
253-         break ;
177+       auto  currentTimepoint = std::chrono::steady_clock::now ();
178+       auto  timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
179+           currentTimepoint - workStartTime_);
180+       if  (timeElapsed >= timeout) {
181+         std::string exceptionMsg = c10::str (
182+             " Work ran for " 
183+             timeElapsed.count (),
184+             "  milliseconds before timing out." 
185+         TORCH_CHECK (false , exceptionMsg)
254186      }
187+ 
255188      std::this_thread::sleep_for (
256189          std::chrono::milliseconds (kSynchronizeBusyWaitMillis ));
257190    }
0 commit comments