Commit b8f1c52

Specially handle the inplace case in the auto parallel pass.
1 parent fb2bd26 commit b8f1c52

File tree

9 files changed (+78, -18 lines)

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py (-1)

@@ -158,7 +158,6 @@
     'c_allreduce_avg',
     'c_allreduce_max',
     'c_allreduce_min',
-    'c_allreduce_sum',
     'c_allreduce_prod',
     'c_embedding',
     'c_identity',

paddle/fluid/pir/dialect/operator/ir/manual_op.cc (+17)

@@ -3866,6 +3866,7 @@ void AssignOut_Op::Build(pir::Builder &builder,
   std::vector<pir::Type> argument_outputs =
       AssignOut_Op::InferMeta(argument_inputs, &argument_attributes);
   argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+  argument.AddAttributes(argument_attributes);
   constexpr char kStopGradientAttrName[] = "stop_gradient";
   auto stop_gradient0 =
       argument.inputs[0].attribute<pir::BoolAttribute>(kStopGradientAttrName);

@@ -3970,6 +3971,22 @@ std::vector<pir::Type> AssignOut_Op::InferMeta(
       dense_out.layout(),
       dense_out.lod(),
       dense_out.offset());
+#ifdef PADDLE_WITH_DISTRIBUTE
+  // Auto Parallel condition
+  if (auto dist_type = input_values[1].type().dyn_cast<DistTypeInterface>()) {
+    ProcessMeshAttribute op_mesh = dist_type.process_mesh_attr();
+    auto ctx = pir::IrContext::Instance();
+    std::vector<pir::Attribute> dist_operand_attrs{
+        dist_type.tensor_dist_attr(),
+        dist_type.tensor_dist_attr(),
+    },
+        dist_result_attrs{dist_type.tensor_dist_attr()};
+    argument_outputs.push_back(dist_type);
+    (*p_attributes)[kAttrOpDistAttr] = OperationDistAttribute::get(
+        ctx, op_mesh, dist_operand_attrs, dist_result_attrs);
+    return argument_outputs;
+  }
+#endif
   argument_outputs.push_back(out_dense_tensor_type);

   return argument_outputs;

python/paddle/distributed/auto_parallel/static/pir_pass.py (+45, -1)

@@ -69,6 +69,50 @@ def apply_partition_pass(program):
         assert len(op.operands()) == len(
             op.dist_attr.operands()
         ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
+        assert len(op.results()) == len(
+            op.dist_attr.results()
+        ), f"The number of results and the number of op_dist_attr's results are not equal in op: {op}"
+        # deal with inplace value
+        for out_idx, in_idx in paddle.core.pir.get_op_inplace_info(op).items():
+            operand = op.operand(in_idx)
+            operand_attr = op.dist_attr.operand(in_idx)
+            prev_var = operand.source()
+            if not prev_var.is_dist() or operand_attr == prev_var.dist_attr():
+                continue
+            assert (
+                not prev_var.is_combine()
+            ), f"The current partition pass does not support the inplace value of {op} being a tensor list."
+            operand_attr = operand_attr.as_tensor_dist_attr()
+            # reshard input
+            paddle.pir.set_insertion_point(op)
+            reshard_var = paddle._C_ops.reshard_v2(prev_var, operand_attr)
+            operand.set_source(reshard_var)
+
+            result = op.result(out_idx)
+            result_attr = op.dist_attr.result(out_idx).as_tensor_dist_attr()
+            assert (
+                operand_attr == result_attr
+            ), f"For an inplace value, the operand dist attr should equal the result dist attr; please check the infer_spmd func of {op}"
+
+            # reshard output
+            paddle.pir.set_insertion_point_after(op)
+            old_dist_attr = result.dist_attr()
+            result.update_dist_attr(result_attr)
+
+            # reshard output to assign out input
+            reshard_var_1 = paddle._C_ops.reshard_v2(
+                result, prev_var.dist_attr()
+            )
+            paddle.assign(reshard_var_1, prev_var)
+
+            if old_dist_attr == result.dist_attr():
+                continue
+            reshard_var_2 = reshard_var_1
+            if old_dist_attr != reshard_var_1.dist_attr():
+                reshard_var_2 = paddle._C_ops.reshard_v2(result, old_dist_attr)
+            result.replace_all_uses_with(reshard_var_1)
+            reshard_var_1.get_defining_op().operand(0).set_source(result)
+            reshard_var_2.get_defining_op().operand(0).set_source(result)

         for operand, attr in zip(op.operands(), op.dist_attr.operands()):
             prev_var = operand.source()

@@ -187,7 +231,7 @@ def remove_other_rank_op_pass(dist_program):


 # Note: this is the pass in the dense program
-comm_ops = ["pd_op.c_allreduce_sum_", "pd_op.c_allgather"]
+comm_ops = ["pd_op.c_allreduce_sum", "pd_op.c_allgather"]


 def remove_unuseful_comm_op_pass(program):
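
The apply_partition_pass change above encodes a simple rule for inplace ops: reshard the input to the distribution chosen by infer_spmd, require the operand and result dist attrs to match (they alias the same buffer), then reshard the result back and assign it into the original value so the pre-pass distribution is preserved. The following is a minimal, self-contained sketch of that rule; TensorDistAttr, reshard, and handle_inplace here are hypothetical stand-ins for illustration only, not Paddle's pir API.

from dataclasses import dataclass


@dataclass(frozen=True)
class TensorDistAttr:        # stand-in for a pir tensor dist attr
    mesh: tuple              # process mesh, e.g. (0, 1)
    dims_mapping: tuple      # -1 = replicated, i = sharded along mesh axis i


def reshard(attr, target):
    # Stand-in for paddle._C_ops.reshard_v2: move a value to the target distribution.
    return target


def handle_inplace(prev_attr, operand_attr, result_attr):
    # 1. Reshard the input to the distribution the op's infer_spmd chose.
    in_attr = reshard(prev_attr, operand_attr) if prev_attr != operand_attr else prev_attr
    # 2. For an inplace op the operand and result dist attrs must agree,
    #    otherwise the aliased buffer would need two layouts at once.
    assert in_attr == result_attr, "operand dist attr must equal result dist attr"
    # 3. Reshard the result back so the original value keeps its old distribution.
    return reshard(result_attr, prev_attr) if result_attr != prev_attr else result_attr


replicated = TensorDistAttr(mesh=(0, 1), dims_mapping=(-1, -1))
sharded = TensorDistAttr(mesh=(0, 1), dims_mapping=(0, -1))
# The inplace value ends up with its original (replicated) distribution.
assert handle_inplace(replicated, sharded, sharded) == replicated

In the real pass the reshards are inserted as program ops via paddle.pir.set_insertion_point and paddle._C_ops.reshard_v2, and paddle.assign writes the resharded result back into the original inplace value.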

python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py (+1, -1)

@@ -48,7 +48,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
             reduce_mean = True

         group = new_process_group(sorted(src_mesh.process_ids))
-        reduced_value = paddle._C_ops.c_allreduce_sum_(
+        reduced_value = paddle._C_ops.c_allreduce_sum(
             src_value, group.id, True, False
         )


test/auto_parallel/hybrid_strategy/pir_reshard_nd_mesh.py (+6, -6)

@@ -99,8 +99,8 @@ def run_pp_to_rr_case(self):
         new_ops_name = [op.name() for op in dist_program.global_block().ops]

         rank_id = dist.get_rank()
-        assert new_ops_name[-2] == "pd_op.c_allreduce_sum_"
-        assert new_ops_name[-1] == "pd_op.c_allreduce_sum_"
+        assert new_ops_name[-2] == "pd_op.c_allreduce_sum"
+        assert new_ops_name[-1] == "pd_op.c_allreduce_sum"

         # check the first allreduce_sum
         op = new_ops[-2]

@@ -151,11 +151,11 @@ def run_pr_to_rs_case(self):
         new_ops_name = [op.name() for op in dist_program.global_block().ops]

         rank_id = dist.get_rank()
-        assert "pd_op.c_allreduce_sum_" in new_ops_name
+        assert "pd_op.c_allreduce_sum" in new_ops_name
         assert new_ops_name[-1] == "pd_op.slice"

         # check the allreduce_sum
-        op = new_ops[new_ops_name.index("pd_op.c_allreduce_sum_")]
+        op = new_ops[new_ops_name.index("pd_op.c_allreduce_sum")]
         if rank_id == 0 or rank_id == 2:
             process_ids = [0, 2]
         elif rank_id == 1 or rank_id == 3:

@@ -278,11 +278,11 @@ def run_ps_to_ps_case(self):

         ops = dist_program.global_block().ops
         op_names = [op.name() for op in ops]
-        assert "pd_op.c_allreduce_sum_" in op_names
+        assert "pd_op.c_allreduce_sum" in op_names
         assert "pd_op.c_allgather" in op_names
         assert "pd_op.slice" in op_names

-        allreduce_sum_op = ops[op_names.index("pd_op.c_allreduce_sum_")]
+        allreduce_sum_op = ops[op_names.index("pd_op.c_allreduce_sum")]
         allgather_op = ops[op_names.index("pd_op.c_allgather")]
         slice_op = ops[op_names.index("pd_op.slice")]

test/auto_parallel/hybrid_strategy/pir_reshard_nd_mesh_cross_mesh.py (+2, -2)

@@ -105,8 +105,8 @@ def run_pp_to_rr_case(self):
             assert new_ops_name[2] == "pd_op.send_v2"
         else:
             assert new_ops_name[2] == "pd_op.recv_v2"
-        assert new_ops_name[-2] == "pd_op.c_allreduce_sum_"
-        assert new_ops_name[-1] == "pd_op.c_allreduce_sum_"
+        assert new_ops_name[-2] == "pd_op.c_allreduce_sum"
+        assert new_ops_name[-1] == "pd_op.c_allreduce_sum"

         # check the first allreduce_sum
         op = new_ops[-2]

test/auto_parallel/pir/test_to_static_pir_program.py (+1, -1)

@@ -134,7 +134,7 @@ def test_to_static_program(self):
                 "pd_op.sgd_",
                 "pd_op.sgd_",
                 "pd_op.relu_grad",
-                "pd_op.c_allreduce_sum_",
+                "pd_op.c_allreduce_sum",
                 "pd_op.matmul_grad",
                 "pd_op.relu_grad",
                 "pd_op.matmul_grad",

test/auto_parallel/reshard_p_to_r.py (+3, -3)

@@ -97,12 +97,12 @@ def run_pir_static_test_case(self):
                 'builtin.parameter',
                 'pd_op.data',
                 'dist_op.shard_tensor',
-                'pd_op.c_allreduce_sum_',
+                'pd_op.c_allreduce_sum',
             ],
         )

         for op in ops:
-            if op.name() == 'pd_op.c_allreduce_sum_':
+            if op.name() == 'pd_op.c_allreduce_sum':
                 # check op dist_attr
                 assert op.dist_attr.num_operands() == 1
                 assert op.dist_attr.num_results() == 1

@@ -167,7 +167,7 @@ def run_pir_to_static_test_case(self):
                 "pd_op.sgd_",
                 "pd_op.sgd_",
                 "pd_op.relu_grad",
-                "pd_op.c_allreduce_sum_",
+                "pd_op.c_allreduce_sum",
                 "pd_op.matmul_grad",
                 "pd_op.relu_grad",
                 "pd_op.matmul_grad",

test/auto_parallel/reshard_p_to_r_cross_mesh.py (+3, -3)

@@ -97,7 +97,7 @@ def run_pir_static_test_case(self):
                 'dist_op.shard_tensor',
                 'pd_op.send_v2',
                 'dist_op.reshard',
-                'pd_op.c_allreduce_sum_',
+                'pd_op.c_allreduce_sum',
             ]
         else:
             np.testing.assert_equal(main_program.num_ops(), 5)

@@ -106,7 +106,7 @@ def run_pir_static_test_case(self):
                 'pd_op.data',
                 'dist_op.shard_tensor',
                 'pd_op.recv_v2',
-                'pd_op.c_allreduce_sum_',
+                'pd_op.c_allreduce_sum',
             ]
             np.testing.assert_equal(
                 ops,

@@ -141,7 +141,7 @@ def run_pir_static_test_case(self):
                 assert op_result_dist_attr.partial_status == {
                     0: paddle.distributed.ReduceType.kRedSum
                 }
-            elif op.name() == 'pd_op.c_allreduce_sum_':
+            elif op.name() == 'pd_op.c_allreduce_sum':
                 continue
             # check op dist_attr
             assert op.dist_attr.num_operands() == 1
