@@ -69,6 +69,50 @@ def apply_partition_pass(program):
69
69
assert len (op .operands ()) == len (
70
70
op .dist_attr .operands ()
71
71
), f"The number of operands and the number of op_dist_attr's operands are not equal in op: { op } "
72
+ assert len (op .results ()) == len (
73
+ op .dist_attr .results ()
74
+ ), f"The number of results and the number of op_dist_attr's results are not equal in op: { op } "
75
+ # deal with inplace value
76
+ for out_idx , in_idx in paddle .core .pir .get_op_inplace_info (op ).items ():
77
+ operand = op .operand (in_idx )
78
+ operand_attr = op .dist_attr .operand (in_idx )
79
+ prev_var = operand .source ()
80
+ if not prev_var .is_dist () or operand_attr == prev_var .dist_attr ():
81
+ continue
82
+ assert (
83
+ not prev_var .is_combine ()
84
+ ), f"The current partition pass not support inplace value of { op } is tensor list."
85
+ operand_attr = operand_attr .as_tensor_dist_attr ()
86
+ # reshard input
87
+ paddle .pir .set_insertion_point (op )
88
+ reshard_var = paddle ._C_ops .reshard_v2 (prev_var , operand_attr )
89
+ operand .set_source (reshard_var )
90
+
91
+ result = op .result (out_idx )
92
+ result_attr = op .dist_attr .result (out_idx ).as_tensor_dist_attr ()
93
+ assert (
94
+ operand_attr == result_attr
95
+ ), f"For inplace value, The operend dist attr should be equal to result dist attr , please check your infer_spmd func of { op } "
96
+
97
+ # reshard output
98
+ paddle .pir .set_insertion_point_after (op )
99
+ old_dist_attr = result .dist_attr ()
100
+ result .update_dist_attr (result_attr )
101
+
102
+ # reshard output to assign out input
103
+ reshard_var_1 = paddle ._C_ops .reshard_v2 (
104
+ result , prev_var .dist_attr ()
105
+ )
106
+ paddle .assign (reshard_var_1 , prev_var )
107
+
108
+ if old_dist_attr == result .dist_attr ():
109
+ continue
110
+ reshard_var_2 = reshard_var_1
111
+ if old_dist_attr != reshard_var_1 .dist_attr ():
112
+ reshard_var_2 = paddle ._C_ops .reshard_v2 (result , old_dist_attr )
113
+ result .replace_all_uses_with (reshard_var_1 )
114
+ reshard_var_1 .get_defining_op ().operand (0 ).set_source (result )
115
+ reshard_var_2 .get_defining_op ().operand (0 ).set_source (result )
72
116
73
117
for operand , attr in zip (op .operands (), op .dist_attr .operands ()):
74
118
prev_var = operand .source ()
0 commit comments