
Commit ba221cf

special deal with the inplace case in auto parallel pass.
1 parent: fb2bd26

File tree

4 files changed: +62 -2 lines

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py (-1)

@@ -158,7 +158,6 @@
     'c_allreduce_avg',
     'c_allreduce_max',
     'c_allreduce_min',
-    'c_allreduce_sum',
     'c_allreduce_prod',
     'c_embedding',
     'c_identity',
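
Dropping 'c_allreduce_sum' from this generator list appears to enable a generated Python binding for the non-inplace op, paddle._C_ops.c_allreduce_sum, which the p_to_r_reshard_func.py change at the end of this commit switches to. A minimal sketch of the resulting call; the helper name is hypothetical, the argument order (src_value, ring_id, use_calc_stream, use_model_parallel) is taken from that diff, and a multi-process distributed launch is assumed:

    # Illustrative sketch, not a standalone script: it assumes a
    # multi-process distributed launch and an initialized process group.
    import paddle

    def allreduce_sum_out_of_place(src_value, ring_id):
        # Non-inplace binding: returns a fresh value instead of mutating
        # src_value, unlike the previous paddle._C_ops.c_allreduce_sum_ call.
        return paddle._C_ops.c_allreduce_sum(src_value, ring_id, True, False)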

paddle/fluid/pir/dialect/operator/ir/manual_op.cc (+17)

@@ -3866,6 +3866,7 @@ void AssignOut_Op::Build(pir::Builder &builder,
   std::vector<pir::Type> argument_outputs =
       AssignOut_Op::InferMeta(argument_inputs, &argument_attributes);
   argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+  argument.AddAttributes(argument_attributes);
   constexpr char kStopGradientAttrName[] = "stop_gradient";
   auto stop_gradient0 =
       argument.inputs[0].attribute<pir::BoolAttribute>(kStopGradientAttrName);
@@ -3970,6 +3971,22 @@ std::vector<pir::Type> AssignOut_Op::InferMeta(
                                  dense_out.layout(),
                                  dense_out.lod(),
                                  dense_out.offset());
+#ifdef PADDLE_WITH_DISTRIBUTE
+  // Auto Parallel condition
+  if (auto dist_type = input_values[1].type().dyn_cast<DistTypeInterface>()) {
+    ProcessMeshAttribute op_mesh = dist_type.process_mesh_attr();
+    auto ctx = pir::IrContext::Instance();
+    std::vector<pir::Attribute> dist_operand_attrs{
+        dist_type.tensor_dist_attr(),
+        dist_type.tensor_dist_attr(),
+    },
+        dist_result_attrs{dist_type.tensor_dist_attr()};
+    argument_outputs.push_back(dist_type);
+    (*p_attributes)[kAttrOpDistAttr] = OperationDistAttribute::get(
+        ctx, op_mesh, dist_operand_attrs, dist_result_attrs);
+    return argument_outputs;
+  }
+#endif
   argument_outputs.push_back(out_dense_tensor_type);
 
   return argument_outputs;
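
The new branch gives AssignOut_Op dist-aware type inference: when the second input carries a distributed type, the output reuses that type and the op records the mesh plus operand/result dist attrs under kAttrOpDistAttr. A pure-Python model of that decision, using hypothetical stand-in classes rather than real Paddle APIs, shows the flow:

    # Hypothetical stand-ins mimicking the shape of the C++ branch above;
    # none of these classes or dict keys are real Paddle APIs.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class DistType:
        mesh: str              # stands in for ProcessMeshAttribute
        tensor_dist_attr: str  # stands in for TensorDistAttr

    def infer_meta(input_types, attributes, dense_out_type):
        maybe_dist = input_types[1]
        if isinstance(maybe_dist, DistType):
            # Mirror the PADDLE_WITH_DISTRIBUTE branch: reuse the input's
            # dist type for the output and record operand/result dist attrs.
            attributes["op_dist_attr"] = {
                "mesh": maybe_dist.mesh,
                "operands": [maybe_dist.tensor_dist_attr] * 2,
                "results": [maybe_dist.tensor_dist_attr],
            }
            return [maybe_dist]
        return [dense_out_type]

    attrs = {}
    outs = infer_meta(
        ["dense_x", DistType("mesh[0,1]", "Replicate()")], attrs, "dense_out"
    )
    print(outs, attrs)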

python/paddle/distributed/auto_parallel/static/pir_pass.py (+44)

@@ -69,6 +69,50 @@ def apply_partition_pass(program):
         assert len(op.operands()) == len(
             op.dist_attr.operands()
         ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
+        assert len(op.results()) == len(
+            op.dist_attr.results()
+        ), f"The number of results and the number of op_dist_attr's results are not equal in op: {op}"
+        # deal with inplace values
+        for out_idx, in_idx in paddle.core.pir.get_op_inplace_info(op).items():
+            operand = op.operand(in_idx)
+            operand_attr = op.dist_attr.operand(in_idx)
+            prev_var = operand.source()
+            if not prev_var.is_dist() or operand_attr == prev_var.dist_attr():
+                continue
+            assert (
+                not prev_var.is_combine()
+            ), f"The partition pass does not support an inplace value of {op} that is a tensor list."
+            operand_attr = operand_attr.as_tensor_dist_attr()
+            # reshard input
+            paddle.pir.set_insertion_point(op)
+            reshard_var = paddle._C_ops.reshard_v2(prev_var, operand_attr)
+            operand.set_source(reshard_var)
+
+            result = op.result(out_idx)
+            result_attr = op.dist_attr.result(out_idx).as_tensor_dist_attr()
+            assert (
+                operand_attr == result_attr
+            ), f"For an inplace value, the operand dist attr should equal the result dist attr; please check the infer_spmd func of {op}"
+
+            # reshard output
+            paddle.pir.set_insertion_point_after(op)
+            old_dist_attr = result.dist_attr()
+            result.update_dist_attr(result_attr)
+
+            # reshard the output back to the dist attr of the inplace input
+            reshard_var_1 = paddle._C_ops.reshard_v2(
+                result, prev_var.dist_attr()
+            )
+            paddle.assign(reshard_var_1, prev_var)
+
+            if old_dist_attr == result.dist_attr():
+                continue
+            reshard_var_2 = reshard_var_1
+            if old_dist_attr != reshard_var_1.dist_attr():
+                reshard_var_2 = paddle._C_ops.reshard_v2(result, old_dist_attr)
+            result.replace_all_uses_with(reshard_var_2)
+            reshard_var_1.get_defining_op().operand(0).set_source(result)
+            reshard_var_2.get_defining_op().operand(0).set_source(result)
 
         for operand, attr in zip(op.operands(), op.dist_attr.operands()):
             prev_var = operand.source()
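
The loop above rewires the graph around each inplace op: the input is resharded to the dist attr the op's SPMD rule expects, the result is resharded back and assigned into the original inplace buffer, and, when downstream code was built against the old result dist attr, a second reshard recreates a value with that attr for the remaining users. A pure-Python model of the rewiring, using hypothetical stand-ins rather than real pir APIs, makes the three reshards concrete:

    # Hypothetical stand-ins for pir values; not Paddle APIs.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Var:
        name: str
        dist_attr: str
        users: List[str] = field(default_factory=list)

    def reshard(src: Var, dst_attr: str) -> Var:
        # stands in for paddle._C_ops.reshard_v2
        return Var(f"reshard({src.name}->{dst_attr})", dst_attr)

    def partition_inplace(prev_var: Var, operand_attr: str,
                          old_result_attr: str) -> Var:
        # 1. reshard the input to the dist attr the op's spmd rule expects
        op_input = reshard(prev_var, operand_attr)
        # 2. for an inplace op the result carries the operand's dist attr
        result = Var(f"op({op_input.name})", operand_attr)
        # 3. reshard the result back and assign into the inplace buffer
        back = reshard(result, prev_var.dist_attr)
        prev_var.users.append(f"assign({back.name})")
        # 4. if downstream users expected the old result dist attr, hand
        #    them a second reshard (reshard_var_2 in the pass above)
        if old_result_attr != result.dist_attr:
            return reshard(result, old_result_attr)
        return result

    downstream_view = partition_inplace(
        Var("x", "Shard(0)"), "Replicate()", "Shard(0)"
    )
    print(downstream_view)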

python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py (+1 -1)

@@ -48,7 +48,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
             reduce_mean = True
 
         group = new_process_group(sorted(src_mesh.process_ids))
-        reduced_value = paddle._C_ops.c_allreduce_sum_(
+        reduced_value = paddle._C_ops.c_allreduce_sum(
            src_value, group.id, True, False
        )
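
Presumably this switch keeps the reshard path from introducing a fresh inplace op right after apply_partition_pass has rewritten inplace values: c_allreduce_sum returns a new value, while the trailing-underscore c_allreduce_sum_ mutates src_value in place. A toy contrast of the two calling conventions (pure Python; the helper names are hypothetical, not Paddle APIs):

    # Toy illustration of out-of-place vs. inplace conventions.
    def allreduce_sum(values):      # out-of-place: returns a new list
        total = sum(values)
        return [total for _ in values]

    def allreduce_sum_(values):     # inplace: mutates its argument
        total = sum(values)
        values[:] = [total for _ in values]
        return values

    buf = [1.0, 2.0]
    fresh = allreduce_sum(buf)      # buf unchanged; fresh is a new value
    allreduce_sum_(buf)             # buf now holds the reduced result
    print(fresh, buf)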
