Skip to content

Commit 8718d78

Browse files
authored
[Dist Dialect] Add MoE-related api in PIR dist dialect (#66462)
* add two MoE api in distributed dialect * polish the dist_op and add unit test * remove simple_net_ep unit test * remove redundant print * bug fix, replace platform::errors with phi::errors
1 parent dbfc48b commit 8718d78

File tree

12 files changed

+833
-2
lines changed

12 files changed

+833
-2
lines changed

paddle/fluid/pir/dialect/distributed/ir/dist_api.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,53 @@ pir::Value reshard(const pir::Value& x,
6363
return reshard_op.result(0);
6464
}
6565

66+
std::vector<pir::Value> local_tensors_from_dist(
67+
const pir::Value& input,
68+
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
69+
const std::vector<int64_t>& local_dims_mapping,
70+
const flat_hash_map<int64_t, phi::ReduceType>& local_partial_status,
71+
const phi::distributed::ProcessMesh& global_mesh,
72+
const std::vector<int64_t>& global_dims_mapping,
73+
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status) {
74+
pir::IrContext* ctx = pir::IrContext::Instance();
75+
std::vector<TensorDistAttribute> local_dist_attrs;
76+
for (const phi::distributed::ProcessMesh& mesh : local_mesh_list) {
77+
local_dist_attrs.emplace_back(TensorDistAttribute::get(
78+
ctx, mesh, local_dims_mapping, local_partial_status));
79+
}
80+
TensorDistAttribute global_dist_attr = TensorDistAttribute::get(
81+
ctx, global_mesh, global_dims_mapping, global_partial_status);
82+
83+
auto op = ApiBuilder::Instance().GetBuilder()->Build<LocalTensorsFromDistOp>(
84+
input, local_dist_attrs, global_dist_attr);
85+
return op.results();
86+
}
87+
88+
pir::Value dist_tensor_from_locals(
89+
const std::vector<pir::Value>& inputs,
90+
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
91+
const std::vector<int64_t>& local_dims_mapping,
92+
const flat_hash_map<int64_t, phi::ReduceType>& local_partial_status,
93+
const phi::distributed::ProcessMesh& global_mesh,
94+
const std::vector<int64_t>& global_dims_mapping,
95+
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status,
96+
const std::vector<int64_t>& global_shape) {
97+
pir::IrContext* ctx = pir::IrContext::Instance();
98+
99+
std::vector<TensorDistAttribute> local_dist_attrs;
100+
for (const phi::distributed::ProcessMesh& mesh : local_mesh_list) {
101+
local_dist_attrs.emplace_back(TensorDistAttribute::get(
102+
ctx, mesh, local_dims_mapping, local_partial_status));
103+
}
104+
105+
TensorDistAttribute global_dist_attr = TensorDistAttribute::get(
106+
ctx, global_mesh, global_dims_mapping, global_partial_status);
107+
108+
phi::DDim global_ddim = phi::make_ddim(global_shape);
109+
110+
auto op = ApiBuilder::Instance().GetBuilder()->Build<DistTensorFromLocalsOp>(
111+
inputs, local_dist_attrs, global_dist_attr, global_ddim);
112+
return op.result(0);
113+
}
114+
66115
} // namespace paddle::dialect

paddle/fluid/pir/dialect/distributed/ir/dist_api.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,24 @@ pir::Value reshard(
4040
pir::Value reshard(const pir::Value& x,
4141
const TensorDistAttribute& tensor_dist_attr);
4242

43+
std::vector<pir::Value> local_tensors_from_dist(
44+
const pir::Value& input,
45+
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
46+
const std::vector<int64_t>& local_dims_mapping,
47+
const flat_hash_map<int64_t, phi::ReduceType>& local_partial_status,
48+
const phi::distributed::ProcessMesh& global_mesh,
49+
const std::vector<int64_t>& global_dims_mapping,
50+
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status);
51+
52+
pir::Value dist_tensor_from_locals(
53+
const std::vector<pir::Value>& inputs,
54+
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
55+
const std::vector<int64_t>& local_dims_mapping,
56+
const flat_hash_map<int64_t, phi::ReduceType>& local_partial_status,
57+
const phi::distributed::ProcessMesh& global_mesh,
58+
const std::vector<int64_t>& global_dims_mapping,
59+
const flat_hash_map<int64_t, phi::ReduceType>& global_partial_status,
60+
const std::vector<int64_t>& global_shape);
61+
4362
} // namespace dialect
4463
} // namespace paddle

paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ void DistDialect::initialize() {
3535
TensorDistAttribute,
3636
OperationDistAttribute>();
3737
RegisterTypes<DistDenseTensorType>();
38-
RegisterOps<ShardTensorOp, ReshardOp>();
38+
RegisterOps<ShardTensorOp,
39+
ReshardOp,
40+
LocalTensorsFromDistOp,
41+
DistTensorFromLocalsOp>();
3942
}
4043

4144
void DistDialect::PrintType(pir::Type type, std::ostream &os) const {

0 commit comments

Comments
 (0)