
Commit b3bdf7e

[AutoParallel] Add einsum spmd rules (#73753)
* add einsum rule
* update cmakelists
* fix test
* apply suggestion
* re-run ci
1 parent 1a09bd1 commit b3bdf7e

File tree: 8 files changed (+996, -0 lines)
Lines changed: 390 additions & 0 deletions
@@ -0,0 +1,390 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <unordered_map>
#include <unordered_set>

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/einsum.h"
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"
#include "paddle/utils/string/string_helper.h"

namespace phi::distributed {

using phi::distributed::auto_parallel::str_join;

void ParseEinsumEquation(const std::string& equation,
                         std::vector<std::string>* operands,
                         std::string* output) {
  auto results = paddle::string::split_string(equation, "->");
  auto left = results[0];
  *operands = paddle::string::split_string(left, ",");
  *output = results[1];
}
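// For example, ParseEinsumEquation("ij,jk->ik", &operands, &output) yields
// operands == {"ij", "jk"} and output == "ik". The equation is expected to
// already be in "lhs->rhs" form with at most two comma-separated operands
// (see the checks in EinsumInferSpmd below).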

void ConstraintOnDiagLabel(std::vector<std::string>* operands,
                           std::string* output) {
  // Empirically, for fwd calculation, only the diagonal labels in the output
  // should not be sharded. e.g. iji->ii (diag), 'i' cannot be sharded;
  // e.g. iji->i (trace), 'i' can be sharded.
  // But during bwd calculation, input and output are switched.
  // e.g. in the 'trace' case above, calculating x_grad uses i->ii,
  // so 'i' cannot be sharded there either.
  // Thus the spmd rule here simply replaces all diagonal labels with '1'.

  // find diagonal labels
  std::unordered_map<char, int> char_count;
  std::unordered_set<char> diagonal_labels;
  for (const auto& op : *operands) {
    for (char c : op) {
      char_count[c]++;
      if (char_count[c] > 1) {
        diagonal_labels.insert(c);
      }
    }
    char_count.clear();
  }
  for (char c : *output) {
    char_count[c]++;
    if (char_count[c] > 1) {
      diagonal_labels.insert(c);
    }
  }

  if (!diagonal_labels.empty()) {
    // replace the diagonal labels in the input operands
    for (size_t i = 0; i < operands->size(); ++i) {
      for (size_t j = 0; j < (*operands)[i].size(); ++j) {
        if (diagonal_labels.find((*operands)[i][j]) != diagonal_labels.end()) {
          (*operands)[i].replace(j, 1, "1");
        }
      }
    }
    // replace the diagonal labels in the output
    for (size_t i = 0; i < output->size(); ++i) {
      if (diagonal_labels.find((*output)[i]) != diagonal_labels.end()) {
        output->replace(i, 1, "1");
      }
    }
  }
}
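// For example, "iji->ii" becomes "1j1->11": every 'i' axis is rewritten to
// the placeholder label '1' so that it cannot be sharded, while 'j' can
// still participate in sharding propagation.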

bool IsEinsumOuter(const std::vector<std::string>& inputs,
                   const std::string& output) {
  // Outer case: e.g. i, j -> ij; ij, kl -> ijkl
  if (inputs.size() != 2) {
    return false;
  }

  std::unordered_map<char, int> input_char_count;
  for (const auto& in : inputs) {
    for (char c : in) {
      input_char_count[c]++;
      if (input_char_count[c] > 1) {
        return false;
      }
    }
  }

  std::unordered_map<char, int> output_char_count;
  for (char c : output) {
    output_char_count[c]++;
  }
  if (input_char_count != output_char_count) {
    return false;
  }
  return true;
}
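// For example, "i,j->ij" and "ij,kl->ijkl" are outer products: every label
// appears exactly once across the two inputs and exactly once in the output.
// In contrast, "ij,jk->ik" is not an outer product, since 'j' is shared by
// both inputs.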

void ConstraintOnOuter(const phi::distributed::TensorDistAttr& x_attr,
                       const phi::distributed::TensorDistAttr& y_attr,
                       int x_ndim,
                       int y_ndim,
                       std::vector<int64_t>* x_dims_mapping,
                       std::vector<int64_t>* y_dims_mapping) {
  // For an outer operation, only one operand and one dimension can be sharded.
  // TODO: when multiple dimensions are requested to be sharded, a smarter
  // policy could decide which operand and which dimension to shard.

  // Here we simply keep the first sharded dimension of the first operand
  // requested to be sharded and reset everything else to replicated.
  if (x_attr.is_shard()) {
    bool meet_shard_axis = false;
    for (int i = 0; i < x_ndim; ++i) {
      if ((*x_dims_mapping)[i] != -1) {
        if (meet_shard_axis) {
          // only the first sharded axis of x is kept
          (*x_dims_mapping)[i] = -1;
        } else {
          meet_shard_axis = true;
        }
      }
    }
    // reset y_dims_mapping to all replicated
    for (int i = 0; i < y_ndim; ++i) {
      (*y_dims_mapping)[i] = -1;
    }
  } else if (y_attr.is_shard()) {
    bool meet_shard_axis = false;
    for (int i = 0; i < y_ndim; ++i) {
      if ((*y_dims_mapping)[i] != -1) {
        if (meet_shard_axis) {
          // only the first sharded axis of y is kept
          (*y_dims_mapping)[i] = -1;
        } else {
          meet_shard_axis = true;
        }
      }
    }
    // no need to reset x_dims_mapping: x is not sharded in this branch
  }
}
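// For example, for "ij,kl->ijkl" with x_dims_mapping = [0, 1] and
// y_dims_mapping = [-1, 2], only x's first sharded axis survives: the result
// is x_dims_mapping = [0, -1] and y_dims_mapping = [-1, -1], so at most one
// mesh dimension shards the outer product.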

SpmdInfo EinsumInferSpmd(const std::vector<DistMetaTensor>& inputs,
                         const std::string& equation) {
  PADDLE_ENFORCE_LE(
      inputs.size(),
      2,
      common::errors::InvalidArgument(
          "EinsumOp only supports len(operands) between (0, 2]. Use "
          "opt_einsum first to convert multi-variable to binary-variable."));

  std::vector<std::string> operands;
  std::string right;
  // ellipsis labels are already parsed in the python API (einsum_v2)
  ParseEinsumEquation(equation, &operands, &right);
  // diagonal case
  ConstraintOnDiagLabel(&operands, &right);

  if (inputs.size() == 1) {
    // single operand
    DistMetaTensor x = inputs[0];
    EXTRACT_SHAPE_AND_DIST_ATTR(x);
    std::vector<int64_t> x_dims_mapping(x_dims_mapping_src);

    VLOG(6) << "EinsumInferSpmd InferForward Inputs: "
            << "X shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
            << str_join(x_dims_mapping) << "]";

    // Step1: Sharding Propagation
    // Step1.1: Merge input shardings
    std::unordered_map<std::string, int64_t> axis_to_dim_map =
        ShardingMergeForTensors({{operands[0], x_dims_mapping}});

    // Step1.2: Infer output dims mapping
    TensorDistAttr x_dist_attr_dst =
        CopyTensorDistAttrForOutput(x_dist_attr_src);
    x_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(operands[0], axis_to_dim_map));

    std::vector<int64_t> fake_output_shape(right.size(), 1);
    TensorDistAttr out_dist_attr_dst(fake_output_shape);
    out_dist_attr_dst.set_process_mesh(x_dist_attr_src.process_mesh());
    out_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(right, axis_to_dim_map));

    // Step2: Handle Partial
    // Step2.1: Output Partial
    std::vector<int64_t> partial_on_dims =
        ResoluteOutputPartialDimension(axis_to_dim_map, right);
    out_dist_attr_dst.set_partial_status(partial_on_dims);
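    // For example, with equation "ij->i" and x_dims_mapping = [0, 1]:
    // axis_to_dim_map = {i: 0, j: 1}, the output "i" gets dims_mapping [0],
    // and the reduced axis 'j' (sharded on mesh dim 1) makes the output
    // partial on mesh dim 1.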

    VLOG(4) << "x_axes: " << operands[0] << " out_axes: " << right;
    LOG_SPMD_INPUT(x);
    VLOG(4) << "out";
    VLOG(4) << "dist_attr: [" << out_dist_attr_dst.to_string() << "]";

    std::vector<TensorDistAttr> input_dist_attrs;
    input_dist_attrs.push_back(x_dist_attr_dst);
    return {{input_dist_attrs}, {out_dist_attr_dst}};
  } else {
    // double operands
    DistMetaTensor x = inputs[0];
    DistMetaTensor y = inputs[1];
    EXTRACT_SHAPE_AND_DIST_ATTR(x);
    EXTRACT_SHAPE_AND_DIST_ATTR(y);
    std::vector<int64_t> x_dims_mapping(x_dims_mapping_src);
    std::vector<int64_t> y_dims_mapping(y_dims_mapping_src);

    if (IsEinsumOuter(operands, right)) {
      ConstraintOnOuter(x_dist_attr_src,
                        y_dist_attr_src,
                        x_ndim,
                        y_ndim,
                        &x_dims_mapping,
                        &y_dims_mapping);
    }
    VLOG(6) << "EinsumInferSpmd InferForward Inputs: "
            << "X shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
            << str_join(x_dims_mapping) << "], Y shape: [" << str_join(y_shape)
            << "], y_dims_mapping: [" << str_join(y_dims_mapping) << "]";

    // Step1: Sharding Propagation
    // Step1.1: Merge input shardings
    std::unordered_map<std::string, int64_t> axis_to_dim_map =
        ShardingMergeForTensors(
            {{operands[0], x_dims_mapping}, {operands[1], y_dims_mapping}});

    // Step1.2: Infer output dims mapping
    TensorDistAttr x_dist_attr_dst =
        CopyTensorDistAttrForOutput(x_dist_attr_src);
    TensorDistAttr y_dist_attr_dst =
        CopyTensorDistAttrForOutput(y_dist_attr_src);
    x_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(operands[0], axis_to_dim_map));
    y_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(operands[1], axis_to_dim_map));

    std::vector<int64_t> fake_output_shape(right.size(), 1);
    TensorDistAttr out_dist_attr_dst(fake_output_shape);
    out_dist_attr_dst.set_process_mesh(x_dist_attr_src.process_mesh());
    out_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(right, axis_to_dim_map));

    // Step2: Handle Partial
    // Step2.1: Output Partial
    std::vector<int64_t> partial_on_dims =
        ResoluteOutputPartialDimension(axis_to_dim_map, right);
    out_dist_attr_dst.set_partial_status(partial_on_dims);
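    // For example, for "ij,jk->ik" (a matmul) with x_dims_mapping = [0, 1]
    // and y_dims_mapping = [1, -1]: the merged map is {i: 0, j: 1, k: -1},
    // the output "ik" gets dims_mapping [0, -1], and the contracted axis 'j'
    // (sharded on mesh dim 1) makes the output partial on mesh dim 1.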

    VLOG(4) << "x_axes: " << operands[0] << " y_axes: " << operands[1]
            << " out_axes: " << right;
    LOG_SPMD_INPUT(x);
    LOG_SPMD_INPUT(y);
    VLOG(4) << "out";
    VLOG(4) << "dist_attr: [" << out_dist_attr_dst.to_string() << "]";

    std::vector<TensorDistAttr> input_dist_attrs;
    input_dist_attrs.push_back(x_dist_attr_dst);
    input_dist_attrs.push_back(y_dist_attr_dst);

    return {{input_dist_attrs}, {out_dist_attr_dst}};
  }
}

SpmdInfo EinsumGradInferSpmd(const std::vector<DistMetaTensor>& inputs,
                             const std::vector<DistMetaTensor>& inner_cache,
                             const DistMetaTensor& out_grad,
                             const std::string& equation) {
  PADDLE_ENFORCE_LE(
      inputs.size(),
      2,
      common::errors::InvalidArgument(
          "EinsumOp only supports len(operands) between (0, 2]. Use "
          "opt_einsum first to convert multi-variable to binary-variable."));

  std::vector<std::string> operands;
  std::string right;
  // ellipsis labels are already parsed in the python API (einsum_v2)
  ParseEinsumEquation(equation, &operands, &right);
  // diagonal case
  ConstraintOnDiagLabel(&operands, &right);

  EXTRACT_SHAPE_AND_DIST_ATTR(out_grad);
  if (inputs.size() == 1) {
    // single operand
    DistMetaTensor x = inputs[0];
    EXTRACT_SHAPE_AND_DIST_ATTR(x);

    // For labels that are reduced in the forward equation, the backward pass
    // uses "right->left": the gradient on those axes is tiled (copied), so we
    // can just copy the dims_mapping on those axes from the input to
    // input_grad. Therefore we also merge the input axes here.
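    // For example, with forward equation "ij->i" and x_dims_mapping = [0, 1]:
    // x_grad has the same shape as x and is just out_grad broadcast along
    // 'j', so x_grad can keep the mapping [0, 1] inherited from x instead of
    // being forced to replicate 'j'.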
    std::unordered_map<std::string, int64_t> axis_to_dim_map =
        ShardingMergeForTensors({{operands[0], x_dims_mapping_src},
                                 {right, out_grad_dims_mapping_src}});

    TensorDistAttr x_dist_attr_dst =
        CopyTensorDistAttrForOutput(x_dist_attr_src);
    x_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(operands[0], axis_to_dim_map));

    TensorDistAttr out_grad_dist_attr_dst(out_grad_dist_attr_src);
    out_grad_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(right, axis_to_dim_map));

    std::vector<TensorDistAttr> input_dist_attrs;
    input_dist_attrs.push_back(x_dist_attr_dst);
    return {{input_dist_attrs, out_grad_dist_attr_dst}, {input_dist_attrs}};
  } else {
    // double operands
    DistMetaTensor x = inputs[0];
    DistMetaTensor y = inputs[1];
    EXTRACT_SHAPE_AND_DIST_ATTR(x);
    EXTRACT_SHAPE_AND_DIST_ATTR(y);
    std::vector<int64_t> x_dims_mapping(x_dims_mapping_src);
    std::vector<int64_t> y_dims_mapping(y_dims_mapping_src);
    std::vector<int64_t> out_grad_dims_mapping(out_grad_dims_mapping_src);

    if (IsEinsumOuter(operands, right)) {
      ConstraintOnOuter(x_dist_attr_src,
                        y_dist_attr_src,
                        x_ndim,
                        y_ndim,
                        &x_dims_mapping,
                        &y_dims_mapping);
    }
    // infer dst dist_attrs for out_grad, x and y from the merged forward
    // sharding
    std::unordered_map<std::string, int64_t> fwd_axis_to_dim_map =
        ShardingMergeForTensors(
            {{operands[0], x_dims_mapping}, {operands[1], y_dims_mapping}});
    out_grad_dims_mapping = GetDimsMappingForAxes(right, fwd_axis_to_dim_map);
    TensorDistAttr out_grad_dist_attr_dst =
        CopyTensorDistAttrForOutput(out_grad_dist_attr_src);
    out_grad_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(right, fwd_axis_to_dim_map));
    TensorDistAttr x_dist_attr_dst =
        CopyTensorDistAttrForOutput(x_dist_attr_src);
    x_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(operands[0], fwd_axis_to_dim_map));
    TensorDistAttr y_dist_attr_dst =
        CopyTensorDistAttrForOutput(y_dist_attr_src);
    y_dist_attr_dst.set_dims_mapping(
        GetDimsMappingForAxes(operands[1], fwd_axis_to_dim_map));

    // For labels that are reduced in the backward equations
    // "left[1], right -> left[0]" and "right, left[0] -> left[1]", the
    // gradient on those axes is tiled (copied), so we can just copy the
    // dims_mapping on those axes from the input to input_grad. Therefore we
    // reuse the forward-inferred input_dist_attr for input_grad_dist_attr and
    // then handle partial.

    // dx = einsum(y, d_out)
    TensorDistAttr x_grad_dist_attr_dst = TensorDistAttr(x_dist_attr_dst);
    std::unordered_map<std::string, int64_t> axis_to_dim_map_for_dx =
        ShardingMergeForTensors(
            {{operands[1], y_dims_mapping}, {right, out_grad_dims_mapping}});
    // Handle Partial for dx
    std::vector<int64_t> partial_on_dx_dims =
        ResoluteOutputPartialDimension(axis_to_dim_map_for_dx, operands[0]);
    x_grad_dist_attr_dst.set_partial_status(partial_on_dx_dims);
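    // For example, for "ij,jk->ik" with the merged forward mapping
    // {i: 0, j: 1, k: -1}: dx contracts over 'k', which is replicated, so
    // partial_on_dx_dims is empty and x_grad is simply sharded as [0, 1];
    // if 'k' were sharded instead, x_grad would become partial on that mesh
    // dimension.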

    // dy = einsum(d_out, x)
    TensorDistAttr y_grad_dist_attr_dst = TensorDistAttr(y_dist_attr_dst);
    std::unordered_map<std::string, int64_t> axis_to_dim_map_for_dy =
        ShardingMergeForTensors(
            {{right, out_grad_dims_mapping}, {operands[0], x_dims_mapping}});
    // Handle Partial for dy
    std::vector<int64_t> partial_on_dy_dims =
        ResoluteOutputPartialDimension(axis_to_dim_map_for_dy, operands[1]);
    y_grad_dist_attr_dst.set_partial_status(partial_on_dy_dims);

    std::vector<TensorDistAttr> input_dist_attrs;
    input_dist_attrs.push_back(x_dist_attr_dst);
    input_dist_attrs.push_back(y_dist_attr_dst);
    std::vector<TensorDistAttr> input_grad_dist_attrs;
    input_grad_dist_attrs.push_back(x_grad_dist_attr_dst);
    input_grad_dist_attrs.push_back(y_grad_dist_attr_dst);
    return {{input_dist_attrs, out_grad_dist_attr_dst},
            {input_grad_dist_attrs}};
  }
}
}  // namespace phi::distributed
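As a quick, self-contained illustration of the label bookkeeping implemented above (not part of this commit), the following standalone C++ sketch reproduces the equation parsing and the diagonal-label rewrite using only the standard library; Split is a hypothetical stand-in for paddle::string::split_string.

#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for paddle::string::split_string.
static std::vector<std::string> Split(const std::string& s,
                                      const std::string& sep) {
  std::vector<std::string> parts;
  size_t begin = 0, end;
  while ((end = s.find(sep, begin)) != std::string::npos) {
    parts.push_back(s.substr(begin, end - begin));
    begin = end + sep.size();
  }
  parts.push_back(s.substr(begin));
  return parts;
}

int main() {
  // Parse "iji->ii" the same way ParseEinsumEquation does.
  auto sides = Split("iji->ii", "->");
  std::vector<std::string> operands = Split(sides[0], ",");
  std::string output = sides[1];
  assert(operands.size() == 1 && operands[0] == "iji" && output == "ii");

  // Mark labels repeated inside any operand or inside the output as diagonal,
  // then rewrite them to '1', mirroring ConstraintOnDiagLabel.
  std::unordered_set<char> diagonal;
  std::unordered_map<char, int> count;
  for (const auto& op : operands) {
    count.clear();
    for (char c : op) {
      if (++count[c] > 1) diagonal.insert(c);
    }
  }
  count.clear();
  for (char c : output) {
    if (++count[c] > 1) diagonal.insert(c);
  }
  auto rewrite = [&](std::string* s) {
    for (char& c : *s) {
      if (diagonal.count(c)) c = '1';
    }
  };
  for (auto& op : operands) rewrite(&op);
  rewrite(&output);
  assert(operands[0] == "1j1" && output == "11");
  return 0;
}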
