 #include "paddle/fluid/imperative/parallel_context.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/operators/strided_memcpy.h"
+#ifdef PADDLE_WITH_XPU_BKCL
+#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
+#endif
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/phi/core/dense_tensor.h"
 namespace paddle {
@@ -431,10 +434,6 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
   VLOG(3) << "Start construct the Reducer ...";
   nrings_ = parallel_ctx->GetNRings();
   nranks_ = parallel_ctx->GetNRanks();
-#ifdef PADDLE_WITH_XPU_BKCL
-  comm_pool_.reset(new ::ThreadPool(1));
-  comm_op_count_ = 0;
-#endif
   // initialize groups
   InitializeGroups(group_indices);
   for (size_t global_var_index = 0; global_var_index < vars_.size();
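
For context on this hunk: `comm_pool_` was a single-thread `::ThreadPool` the BKCL build used to push fused allreduces off the main thread, and `comm_op_count_` tracked in-flight ops so `FinalizeBackward` could drain them (both uses are deleted in the hunks below). A minimal, self-contained sketch of that counted-async pattern, assuming standard-library stand-ins for the removed members (not Paddle code):

```cpp
// Sketch of the counted async-dispatch pattern this commit removes.
// Assumption: std::thread stands in for ::ThreadPool(1); the names mirror
// the deleted members but this is a simplified illustration only.
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>

std::mutex mutex_;
std::condition_variable cv_;
int comm_op_count_ = 0;

void EnqueueComm(std::function<void()> comm_op) {
  {
    std::lock_guard<std::mutex> lock(mutex_);
    ++comm_op_count_;  // one more comm op in flight
  }
  std::thread([op = std::move(comm_op)] {
    op();  // e.g. FusedAllReduceSchedule(...)
    std::lock_guard<std::mutex> lock(mutex_);
    --comm_op_count_;
    cv_.notify_all();  // wake anyone draining the queue
  }).detach();
}

void DrainComm() {  // what the deleted cv_.wait in FinalizeBackward did
  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [] { return comm_op_count_ == 0; });
}
```

Removing this machinery makes the XPU path schedule communication synchronously on the calling thread, matching the NCCL/GLOO branches below.
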
@@ -853,8 +852,23 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
 
 #ifdef PADDLE_WITH_XPU_BKCL
   if (platform::is_xpu_place(group_tensor.place())) {
-    // TODO(liuyuhui) support XPU set constant
-    VLOG(3) << "XPU doesn't support set_constant";
+    auto dev_ctx = static_cast<platform::XPUDeviceContext *>(
+        platform::DeviceContextPool::Instance().Get(place_));
+    if (HasGrad(var_index)) {
+      auto var_base = vars_[var_index]->GradVarBase();
+      auto tensor =
+          var_base->MutableVar()->GetMutable<framework::LoDTensor>();
+      group_tensor.ShareDataWith(*tensor).Resize(
+          {static_cast<int64_t>(length)});
+    } else {
+      group_tensor.Resize({static_cast<int64_t>(length)});
+      int r = xpu::constant(dev_ctx->x_context(),
+                            reinterpret_cast<float *>(group_tensor.data()),
+                            group_tensor.numel(),
+                            0.0f);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+      PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx->stream()));
+    }
   }
 #elif defined(PADDLE_WITH_CNCL)
   if (platform::is_mlu_place(group_tensor.place())) {
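
The new branch replaces the old TODO: when the variable produced a gradient, the group slice simply aliases it via `ShareDataWith`; when it did not (an unused parameter), the slice must be zero-filled so the fused allreduce sees zeros rather than garbage. Isolating that fill as a helper, a sketch assuming float gradients and a hypothetical function name:

```cpp
// Sketch of the XPU zero-fill added above (assumptions: float-only, helper
// name ZeroFillXpuSlice is hypothetical; the calls mirror the hunk).
void ZeroFillXpuSlice(platform::XPUDeviceContext *dev_ctx,
                      framework::Tensor *group_tensor, int64_t length) {
  group_tensor->Resize({length});
  // xpu::constant launches an XDNN kernel that writes `length` zeros.
  int r = xpu::constant(dev_ctx->x_context(),
                        reinterpret_cast<float *>(group_tensor->data()),
                        group_tensor->numel(),
                        0.0f);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");  // map XDNN status to an enforce
  // Block the host until the fill lands before the slice is concatenated.
  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx->stream()));
}
```

The `xpu_wait` is a host-side synchronization on the device stream, which is what lets the later concat-and-allreduce path proceed without an extra XPU-specific barrier.
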
@@ -948,33 +962,7 @@ void Reducer::MarkGroupReady(size_t group_index) {
     // so we expose WaitCompute() interface and call
     // it here.
     parallel_ctx_->WaitCompute(run_order);
-#ifdef PADDLE_WITH_XPU_BKCL
-    {
-      std::lock_guard<std::mutex> lock(mutex_);
-      comm_op_count_ += 1;  // lock
-    }
-    // TODO(liuyuhui): Add try catch to deal with exception later,
-    // otherwise the main thread will continue to run when an exception is
-    // thrown in comm_pool_.
-    auto next_group = next_group_;
-    comm_pool_->enqueue([this, run_order, next_group, &group] {
-      auto dev_id = place_.device;
-      platform::SetXPUDeviceId(dev_id);
-      FusedAllReduceSchedule(run_order, group, next_group);
-      {
-        std::lock_guard<std::mutex> lock(mutex_);
-        comm_op_count_ -= 1;  // lock
-        cv_.notify_all();
-      }
-    });
-#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \
-    defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) || \
-    defined(PADDLE_WITH_CNCL)
     FusedAllReduceSchedule(run_order, group, next_group_);
-#else
-    PADDLE_THROW(platform::errors::PreconditionNotMet(
-        "Not compiled with BKCL or NCCL or CNCL or GLOO."));
-#endif
   }
 }
 
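With the `#ifdef` ladder gone, `MarkGroupReady` no longer forks per backend: BKCL, NCCL/RCCL, GLOO, ASCEND and CNCL all take the one synchronous call, and the now-unreachable `PADDLE_THROW` fallback disappears with it. Condensed, the per-group flow becomes the following (a sketch; the loop header is paraphrased from the surrounding function and is not part of this hunk):

```cpp
// Unified scheduling sketch after this commit (assumption: loop header and
// member names paraphrased; only WaitCompute/FusedAllReduceSchedule appear
// in the hunk above).
for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
     ++next_group_) {
  auto &group = groups_[next_group_];
  const int run_order = next_group_ % nrings_;  // round-robin over comm rings
  parallel_ctx_->WaitCompute(run_order);  // comm stream waits on compute
  FusedAllReduceSchedule(run_order, group, next_group_);  // same path, all backends
}
```
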
@@ -997,17 +985,6 @@ void Reducer::FusedAllReduceSchedule(const int run_order,
   // group.dense_tensors ---> group.dense_contents_
   group.ConcatTensors(dev_context);
 
-  // NOTE(liuyuhui): ConcatTensors use communication stream, but BKCL only support
-  // default stream for communicating, so there exist some problems in
-  // synchronization. And need to add a WaitComm there.
-  // TODO(liuyuhui): If BKCL support non-blocking communication, it should be
-  // fixed as multi gpus card training.
-#ifdef PADDLE_WITH_XPU_BKCL
-  if (platform::is_xpu_place(group.dense_tensors_[0].place())) {
-    parallel_ctx_->WaitComm(run_order);
-  }
-#endif
-
   group.DivNRanks(dev_context, nranks_);
   // Start allreduce
   parallel_ctx_->AllReduceByStream(
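
The deleted NOTE documented a workaround: `ConcatTensors` runs on the communication stream, but BKCL could only communicate on the default stream, so the XPU path inserted a `WaitComm` between concat and allreduce. Dropping it leaves a single ordering for every backend, presumably safe now that XPU comm work no longer runs on a separate worker thread (an assumption of this commit, not stated in the hunk). The resulting sequence, with argument lists elided as in the hunk:

```cpp
// Post-commit sequence in FusedAllReduceSchedule (sketch; arguments elided
// as in the truncated hunk above).
group.ConcatTensors(dev_context);       // dense_tensors_ -> dense_contents_
group.DivNRanks(dev_context, nranks_);  // pre-scale by 1/nranks before the sum
parallel_ctx_->AllReduceByStream(/* fused buffer, ring run_order, ... */);
```
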
@@ -1135,12 +1112,6 @@ bool Reducer::HasGrad(size_t var_index) {
 void Reducer::FinalizeBackward() {
   groups_need_finalize_ = false;
   grad_need_hooks_ = false;
-#ifdef PADDLE_WITH_XPU_BKCL
-  {
-    std::unique_lock<std::mutex> lock(mutex_);
-    cv_.wait(lock, [&] { return comm_op_count_ == 0; });
-  }
-#endif
 
   // Must prevent compute_stream_ starting until all comm streams have finished
   for (int i = 0; i < nrings_; ++i) {
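
`FinalizeBackward` loses its BKCL-only drain: with no `comm_pool_` there are no in-flight host-side ops to count, so the condition-variable wait is dead code. The trailing context suggests the per-ring stream wait is now the sole barrier; a sketch of that loop, assuming the body (cut off above) is a `WaitComm` per ring:

```cpp
// Sketch (assumption: the loop body is truncated in the hunk; WaitComm(i)
// is inferred from the comment about blocking compute_stream_).
for (int i = 0; i < nrings_; ++i) {
  parallel_ctx_->WaitComm(i);  // compute stream waits on each comm ring
}
```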