Commit c40fb85

Support XPU for dygraph auto-parallel (#70997)
* Support XPU for dygraph auto-parallel
* Fix
* Fix Custom devices
* Fix code style
1 parent bf081f4 commit c40fb85

11 files changed: +190 / -109 lines

paddle/phi/core/distributed/auto_parallel/reshard/global_and_sub_mesh_reshard_function.cc

Lines changed: 5 additions & 0 deletions
@@ -79,6 +79,10 @@ void SubMeshToGlobalReshardFunction::Eval(phi::DeviceContext* dev_ctx,
                                           const TensorDistAttr& out_dist_attr,
                                           DistTensor* out) {
   VLOG(3) << "Call SubMeshToGlobalReshardFunction Eval";
+#if defined(PADDLE_WITH_XPU)
+  PADDLE_THROW(::common::errors::Unimplemented(
+      "Not supported PSendKernel/PRecv on xpu yet."));
+#else
   const TensorDistAttr& in_dist_attr = in.dist_attr();
   const ProcessMesh& in_process_mesh = in_dist_attr.process_mesh();
   const ProcessMesh& out_process_mesh = out_dist_attr.process_mesh();
@@ -132,6 +136,7 @@ void SubMeshToGlobalReshardFunction::Eval(phi::DeviceContext* dev_ctx,
                     GetMutableTensor(out));
   }
   SetDistProps(out, in.dims(), out_dist_attr);
+#endif
 }
 
 }  // namespace phi::distributed
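Every reshard function touched by this commit uses the same compile-time guard. The following is a rough, standalone sketch (plain C++, not Paddle code; only the macro name PADDLE_WITH_XPU is borrowed from the diff) of how the pattern behaves: on an XPU build the unsupported collective path is compiled out and replaced by an explicit "unimplemented" error, so the surrounding logic still builds on every target.

// guard_sketch.cc -- illustrative only, assuming a build system that defines
// PADDLE_WITH_XPU for XPU targets; mirrors the #if/#else guard in the diff above.
#include <iostream>
#include <stdexcept>

void SubMeshToGlobalEvalSketch() {
#if defined(PADDLE_WITH_XPU)
  // Hypothetical XPU build: point-to-point send/recv is not wired up yet,
  // so fail loudly instead of reaching a missing kernel.
  throw std::runtime_error("Not supported PSendKernel/PRecv on xpu yet.");
#else
  // CPU/GPU build: the original send/recv based reshard path would run here.
  std::cout << "running send/recv based reshard path\n";
#endif
}

int main() {
  try {
    SubMeshToGlobalEvalSketch();
  } catch (const std::exception& e) {
    std::cerr << e.what() << '\n';
  }
  return 0;
}

The same guard, with a different collective named in the error message (ReduceScatter, AllGather, AllToAll, PSend/PRecv), appears in the p_to_s, r_to_x, s_to_r, s_to_s, and same_status reshard functions below.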

paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc

Lines changed: 5 additions & 1 deletion
@@ -67,14 +67,18 @@ void ReshardPToSWithPadding(DeviceContext* dev_ctx,
   }
 
   DenseTensor out_reduce_scatter;
+#if defined(PADDLE_WITH_XPU)
+  PADDLE_THROW(::common::errors::Unimplemented(
+      "Not supported Reducescatter on xpu yet."));
+#else
   RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                             ReduceScatter,
                             dtype,
                             process_ids,
                             in_reduce_scatter,
                             static_cast<int64_t>(process_ids.size()),
                             &out_reduce_scatter);
-
+#endif
   DenseTensor out_result;
   if (split_axis != 0) {
     RESHARD_FUNCTOR(

paddle/phi/core/distributed/auto_parallel/reshard/r_to_x_reshard_function.cc

Lines changed: 12 additions & 3 deletions
@@ -60,7 +60,6 @@ void RToXExpandReshardFunction::Eval(phi::DeviceContext* dev_ctx,
   int64_t cur_global_rank = GetCurGlobalRank();
   int64_t root_rank = in_process_ids[0];
   auto all_process_ids = GetUnionProcessIds(in_process_ids, out_process_ids);
-  bool dynamic_shape = true;
   auto dtype = in.dtype();
   const auto& out_partial_status = out_dist_attr.partial_status();
   bool cur_rank_in_out_mesh =
@@ -72,27 +71,37 @@ void RToXExpandReshardFunction::Eval(phi::DeviceContext* dev_ctx,
   if (root_rank == cur_global_rank) {
     for (const auto& out_process_id : out_process_ids) {
       if (out_process_id != root_rank) {
+#if defined(PADDLE_WITH_XPU)
+        PADDLE_THROW(::common::errors::Unimplemented(
+            "Not supported PSendKernel on xpu yet."));
+#else
         RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                                   PSendKernel,
                                   dtype,
                                   all_process_ids,
                                   in.value(),
                                   out_process_id,
-                                  dynamic_shape);
+                                  /*dynamic_shape=*/true);
+#endif
       }
     }
     if (cur_rank_in_out_mesh) {
       result_value = in.value();
     }
   } else {
+#if defined(PADDLE_WITH_XPU)
+    PADDLE_THROW(
+        ::common::errors::Unimplemented("Not supported PRecv on xpu yet."));
+#else
     RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                               PRecv,
                               dtype,
                               all_process_ids,
                               root_rank,
                               {} /*out_shape*/,
-                              dynamic_shape,
+                              /*dynamic_shape=*/true,
                               &result_value);
+#endif
   }
 
   if (cur_rank_in_out_mesh) {

paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc

Lines changed: 30 additions & 14 deletions
@@ -109,25 +109,41 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
     PADDLE_THROW(common::errors::Unimplemented(
         "Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
 #endif
-  } else if (phi::CustomContext::classof(&dev_ctx)) {
-#ifdef PADDLE_WITH_CUSTOM_DEVICE
-    CommContextManager::CreateXCCLCommContext(
-        store, unique_comm_key, dev_ctx.GetPlace(), rank, world_size);
-#endif
-  } else {
+  }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
+  else if (phi::GPUContext::classof(&dev_ctx)) {  // NOLINT
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    if (phi::GPUContext::classof(&dev_ctx)) {
-      CommContextManager::CreateNCCLCommContext(store,
-                                                unique_comm_key,
-                                                static_cast<int>(rank),
-                                                static_cast<int>(world_size));
-    }
+    CommContextManager::CreateNCCLCommContext(store,
+                                              unique_comm_key,
+                                              static_cast<int>(rank),
+                                              static_cast<int>(world_size));
 #else
     PADDLE_THROW(common::errors::Unimplemented(
-        "CommContext is only supported on CPU and GPU for now, other devices "
-        "will be supported later."));
+        "Cannot use nccl on GPU, please turn WITH_NCCL flag on."));
 #endif
   }
+#elif defined(PADDLE_WITH_XPU)
+  else if (phi::XPUContext::classof(&dev_ctx)) {  // NOLINT
+#if defined(PADDLE_WITH_XPU_BKCL)
+    CommContextManager::CreateBKCLCommContext(store,
+                                              unique_comm_key,
+                                              static_cast<int>(rank),
+                                              static_cast<int>(world_size));
+#else
+    PADDLE_THROW(common::errors::Unimplemented(
+        "Cannot use xpu on GPU, please turn WITH_XPU_BKCL flag on."));
+#endif
+  }
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+  else if (phi::CustomContext::classof(&dev_ctx)) {  // NOLINT
+    CommContextManager::CreateXCCLCommContext(
+        store, unique_comm_key, dev_ctx.GetPlace(), rank, world_size);
+  }
+#endif
+  else {  // NOLINT
+    PADDLE_THROW(common::errors::Unimplemented(
+        "CommContext is only supported CPU, GPU, XPU, and CustomDevice."));
+  }
   }
 
   auto* comm_context = CommContextManager::GetInstance().Get(unique_comm_key);
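To make the preprocessor-plus-branch structure above easier to follow, here is a simplified, hypothetical sketch with the Paddle device contexts replaced by a plain string and the comm managers by an enum. Exactly one device branch survives preprocessing (selected by PADDLE_WITH_CUDA/ROCM, PADDLE_WITH_XPU, or PADDLE_WITH_CUSTOM_DEVICE), and inside it a second flag decides whether the matching communication library (NCCL/RCCL, BKCL, or XCCL) was built in.

// comm_backend_sketch.cc -- illustrative only; all names here are stand-ins
// for the CreateOrGetCommContext dispatch shown in the diff above.
#include <stdexcept>
#include <string>

enum class Backend { kGloo, kNccl, kBkcl, kXccl };

Backend SelectCommBackend(const std::string& device) {
  if (device == "cpu") {
    return Backend::kGloo;  // guarded by PADDLE_WITH_GLOO in the real code
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
  else if (device == "gpu") {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    return Backend::kNccl;
#else
    throw std::runtime_error("GPU build without NCCL/RCCL support");
#endif
  }
#elif defined(PADDLE_WITH_XPU)
  else if (device == "xpu") {
#if defined(PADDLE_WITH_XPU_BKCL)
    return Backend::kBkcl;
#else
    throw std::runtime_error("XPU build without BKCL support");
#endif
  }
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
  else if (device == "custom") {
    return Backend::kXccl;
  }
#endif
  else {
    throw std::runtime_error("unsupported device for CommContext");
  }
}

int main() {
  // On any build, "cpu" resolves to the gloo backend; which other branch
  // exists depends entirely on the PADDLE_WITH_* flags defined at compile time.
  return SelectCommBackend("cpu") == Backend::kGloo ? 0 : 1;
}

Closing the brace before the #if, as the commit does, keeps the trailing else clause valid no matter which single device branch is compiled in.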

paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h

Lines changed: 79 additions & 41 deletions
@@ -79,27 +79,63 @@ phi::DDim InferShapeForReshardFromReplicate(
     const TensorDistAttr& dist_attr);
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...)                 \
-  do {                                                                     \
-    if (phi::CPUContext::classof(dev_ctx)) {                               \
-      VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU.";        \
-      PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_CPU(                   \
-          dtype, #fn_name, ([&] {                                          \
-            fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx),      \
-                            __VA_ARGS__);                                  \
-          }));                                                             \
-    } else if (phi::GPUContext::classof(dev_ctx)) {                        \
-      VLOG(4) << "Call `" << #fn_name << "` in Resharding on GPU.";        \
-      PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_GPU(                   \
-          dtype, #fn_name, ([&] {                                          \
-            fn_name<data_t>(static_cast<const GPUContext&>(*dev_ctx),      \
-                            __VA_ARGS__);                                  \
-          }));                                                             \
-    } else {                                                               \
-      PADDLE_THROW(common::errors::Unimplemented(                          \
-          "The %s in reshard only supported on CPU and GPU for now.",      \
-          #fn_name));                                                      \
-    }                                                                      \
+#define DEVICE_CONTEXT GPUContext
+#elif defined(PADDLE_WITH_XPU)
+#define DEVICE_CONTEXT XPUContext
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+#define DEVICE_CONTEXT CustomContext
+#endif
+
+// Some reshard function supports fewer data types on xpu than on gpu. For
+// example, `Transpose`, `Split`, and `Divide` do not support double type.
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#define PD_VISIT_RESHARD_TYPES PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES
+#else
+#define PD_VISIT_RESHARD_TYPES(TYPE, NAME, ...)                               \
+  [&] {                                                                       \
+    const auto& __dtype__ = TYPE;                                             \
+    switch (__dtype__) {                                                      \
+      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
+      PD_PRIVATE_CASE_TYPE(                                                   \
+          NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__)              \
+      PD_PRIVATE_CASE_TYPE(                                                   \
+          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)              \
+      PD_PRIVATE_CASE_TYPE(                                                   \
+          NAME, ::paddle::DataType::FLOAT16, paddle::float16, __VA_ARGS__)    \
+      PD_PRIVATE_CASE_TYPE_BFLOAT16(NAME, __VA_ARGS__)                        \
+      default:                                                                \
+        PD_THROW("Reshard function " #NAME                                    \
+                 " is not implemented"                                        \
+                 " for data type `",                                          \
+                 __dtype__,                                                   \
+                 "`");                                                        \
+    }                                                                         \
+  }()
+#endif
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
+    defined(PADDLE_WITH_XPU)
+#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...)                    \
+  do {                                                                        \
+    if (phi::CPUContext::classof(dev_ctx)) {                                  \
+      VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU.";           \
+      PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_CPU(                      \
+          dtype, #fn_name, ([&] {                                             \
+            fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx),         \
+                            __VA_ARGS__);                                     \
+          }));                                                                \
+    } else if (DEVICE_CONTEXT::classof(dev_ctx)) {                            \
+      VLOG(4) << "Call `" << #fn_name << "` in Resharding on device.";        \
+      PD_VISIT_RESHARD_TYPES(                                                 \
+          dtype, #fn_name, ([&] {                                             \
+            fn_name<data_t>(static_cast<const DEVICE_CONTEXT&>(*dev_ctx),     \
+                            __VA_ARGS__);                                     \
+          }));                                                                \
+    } else {                                                                  \
+      PADDLE_THROW(common::errors::Unimplemented(                             \
+          "The %s in reshard only supported on CPU, GPU, and XPU for now.",   \
+          #fn_name));                                                         \
+    }                                                                         \
   } while (0)
 #else
 #define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
@@ -130,35 +166,37 @@ phi::DDim InferShapeForReshardFromReplicate(
     RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, __VA_ARGS__); \
   } while (0)
 
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#define RESHARD_FUNCTOR_WITHOUT_DTYPE(dev_ctx, fn_name, ...)               \
-  do {                                                                     \
-    if (phi::CPUContext::classof(dev_ctx)) {                               \
-      VLOG(4) << "Call `" << #fn_name                                      \
-              << "`without DType in Resharding on CPU.";                   \
-      fn_name(static_cast<const CPUContext&>(*dev_ctx), __VA_ARGS__);      \
-    } else if (phi::GPUContext::classof(dev_ctx)) {                        \
-      VLOG(4) << "Call `" << #fn_name                                      \
-              << "`without DType in Resharding on GPU.";                   \
-      fn_name(static_cast<const GPUContext&>(*dev_ctx), __VA_ARGS__);      \
-    } else {                                                               \
-      PADDLE_THROW(common::errors::Unimplemented(                          \
-          "The %s in reshard only supported on CPU and GPU for now.",      \
-          #fn_name));                                                      \
-    }                                                                      \
-  } while (0)
-#else
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
+    defined(PADDLE_WITH_XPU)
 #define RESHARD_FUNCTOR_WITHOUT_DTYPE(dev_ctx, fn_name, ...)               \
   do {                                                                     \
     if (phi::CPUContext::classof(dev_ctx)) {                               \
       VLOG(4) << "Call `" << #fn_name                                      \
-              << "`without DType in Resharding on CPU.";                  \
+              << "`without DType in Resharding on CPU.";                   \
       fn_name(static_cast<const CPUContext&>(*dev_ctx), __VA_ARGS__);      \
+    } else if (DEVICE_CONTEXT::classof(dev_ctx)) {                         \
+      VLOG(4) << "Call `" << #fn_name                                      \
+              << "`without DType in Resharding on device.";                \
+      fn_name(static_cast<const DEVICE_CONTEXT&>(*dev_ctx), __VA_ARGS__);  \
     } else {                                                               \
       PADDLE_THROW(common::errors::Unimplemented(                          \
-          "The %s in reshard only supported on CPU for now.", #fn_name));  \
+          "The %s in reshard only supported CPU, GPU, and XPU Device",     \
+          #fn_name));                                                      \
     }                                                                      \
   } while (0)
+#else
+#define RESHARD_FUNCTOR_WITHOUT_DTYPE(dev_ctx, fn_name, ...)               \
+  do {                                                                     \
+    if (phi::CPUContext::classof(dev_ctx)) {                               \
+      VLOG(4) << "Call `" << #fn_name                                      \
+              << "`without DType in Resharding on CPU.";                   \
+      fn_name(static_cast<const CPUContext&>(*dev_ctx), __VA_ARGS__);      \
+    } else {                                                               \
+      PADDLE_THROW(common::errors::Unimplemented(                          \
+          "The %s in reshard only supported CPU, GPU, and XPU Device.",    \
+          #fn_name));                                                      \
+    }                                                                      \
+  } while (0)
 #endif
 
 #define RESHARD_SHORTCUT_IF_FALSE(expr) \
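For readers unfamiliar with Paddle's PD_VISIT_* machinery, the following is a rough, hand-rolled illustration (assumed names, only four data types, no float16/bfloat16 handling) of what the new PD_VISIT_RESHARD_TYPES visitor boils down to on non-CUDA builds: a switch over the runtime dtype that binds the matching C++ type to data_t and invokes the functor body, with unsupported types reported as unimplemented.

// visit_types_sketch.cc -- illustrative only; a simplified stand-in for the
// dispatch that PD_VISIT_RESHARD_TYPES performs in the header above.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

enum class DataType { INT32, INT64, FLOAT32, FLOAT64 };

template <typename Visitor>
void VisitReshardTypes(DataType dtype, const char* name, Visitor&& visitor) {
  switch (dtype) {
    case DataType::INT32:
      visitor(int32_t{});
      break;
    case DataType::INT64:
      visitor(int64_t{});
      break;
    case DataType::FLOAT32:
      visitor(float{});
      break;
    default:
      // FLOAT64 lands here, matching the comment in the diff that some XPU
      // kernels (Transpose, Split, Divide) do not support double.
      throw std::runtime_error(std::string("Reshard function ") + name +
                               " is not implemented for this data type");
  }
}

int main() {
  VisitReshardTypes(DataType::FLOAT32, "AllGatherSketch", [](auto tag) {
    using data_t = decltype(tag);  // plays the role of `data_t` in the macro
    std::cout << "kernel instantiated with sizeof(data_t) = " << sizeof(data_t)
              << '\n';
  });
  return 0;
}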

paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc

Lines changed: 5 additions & 0 deletions
@@ -42,8 +42,13 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx,
   // For balanced split to replicate, we need to do all gather first.
   // If the input value doesn't split on axis 0, we need to split
   // and concat on specific axis.
+#if defined(PADDLE_WITH_XPU)
+  PADDLE_THROW(
+      ::common::errors::Unimplemented("Not supported AllGather on xpu yet."));
+#else
   RESHARD_FUNCTOR_WITH_COMM(
       dev_ctx, AllGather, dtype, process_ids, in, num_of_process, out);
+#endif
 
   if (split_axis != 0 || padding_nums != 0) {
     IntArray sections(std::vector<int64_t>(num_of_process, in.dims()[0]));

paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc

Lines changed: 5 additions & 0 deletions
@@ -97,12 +97,17 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
   }
 
   // 2. use all to all to switch data to other ranks
+#if defined(PADDLE_WITH_XPU)
+  PADDLE_THROW(
+      ::common::errors::Unimplemented("Not supported AllToAll on xpu yet."));
+#else
   RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                             AllToAll,
                             dtype,
                             in_process_ids,
                             in_all_to_all,
                             GetMutableTensor(out));
+#endif
 
   // 3. postprocess, reshape and transpose the output tensor
   if (in_split_axis != 0) {

paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc

Lines changed: 20 additions & 12 deletions
@@ -55,14 +55,6 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
   const auto& out_process_mesh = out_dist_attr.process_mesh();
   const auto& out_process_ids = out_process_mesh.process_ids();
   auto all_process_ids = GetUnionProcessIds(in_process_ids, out_process_ids);
-  auto dtype = in.dtype();
-  // TODO(liyurui): Use dynamic shape will lead to poor performance, but we
-  // don't have any other good idea now. For the following reasons:
-  // 1. We can not ensure the meta being right deduce by the infermeta.
-  // 2. The meta of some kernels can't decide in compile time.
-  // 3. DenseTensor with empty value only need infermeta and skip the real
-  // kernel execution.
-  bool dynamic_shape = true;
 
   // TODO(GhostScreaming): After cross-mesh reshard, current device may
   // needs to execute next layer. When it construct next layer's backward
@@ -86,28 +78,44 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
     int64_t src = iter.first;
     int64_t dst = iter.second;
     if (src == cur_global_rank) {
+#if defined(PADDLE_WITH_XPU)
+      PADDLE_THROW(::common::errors::Unimplemented(
+          "Not supported PSendKernel on xpu yet."));
+#else
       VLOG(3) << "Send from src " << src << " to dst " << dst;
       int64_t dst_local_rank = GetLocalRankInParticipate(all_process_ids, dst);
       // Since send kernel only has input, so we don't need to infermeta
       // actually. According to this reason, just use the kernel directly.
       RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                                 PSendKernel,
-                                dtype,
+                                in.dtype(),
                                 all_process_ids,
                                 in.value(),
                                 dst_local_rank,
-                                dynamic_shape);
+                                /*dynamic_shape=*/true);
+      // TODO(liyurui): Use dynamic shape will lead to poor performance, but we
+      // don't have any other good idea now. For the following reasons:
+      // 1. We can not ensure the meta being right deduce by the infermeta.
+      // 2. The meta of some kernels can't decide in compile time.
+      // 3. DenseTensor with empty value only need infermeta and skip the real
+      // kernel execution.
+#endif
     } else if (dst == cur_global_rank) {
+#if defined(PADDLE_WITH_XPU)
+      PADDLE_THROW(::common::errors::Unimplemented(
+          "Not supported PRecvKernel on xpu yet."));
+#else
       VLOG(3) << "Recv from src " << src << " to dst " << dst;
      int64_t src_local_rank = GetLocalRankInParticipate(all_process_ids, src);
       RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                                 PRecv,
-                                dtype,
+                                in.dtype(),
                                 all_process_ids,
                                 src_local_rank,
                                 {} /*out_shape*/,
-                                dynamic_shape,
+                                /*dynamic_shape=*/true,
                                 GetMutableTensor(out));
+#endif
     }
   }
   SetDistProps(out, in.dims(), out_dist_attr);
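A small style point in this file, separate from the XPU guards: the named locals `bool dynamic_shape = true;` and `dtype` are dropped in favor of passing `in.dtype()` and `/*dynamic_shape=*/true` directly at the call sites. A tiny hypothetical sketch of that argument-comment idiom:

// arg_comment_sketch.cc -- illustrative only; shows the /*name=*/ call-site
// style adopted above instead of a separate named boolean local.
#include <iostream>

void SendSketch(int peer, bool dynamic_shape) {
  std::cout << "send to rank " << peer
            << (dynamic_shape ? " with" : " without") << " dynamic shape\n";
}

int main() {
  // The inline comments keep the literal arguments self-describing.
  SendSketch(/*peer=*/1, /*dynamic_shape=*/true);
  return 0;
}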
