
Commit fd4d2a9

apply suggestion

1 parent 23f08d7 commit fd4d2a9

File tree

1 file changed: +66 additions, -78 deletions

paddle/phi/infermeta/spmd_rules/einsum.cc

Lines changed: 66 additions & 78 deletions
@@ -35,9 +35,13 @@ void ParseEinsumEquation(const std::string& equation,
 
 void ConstraintOnDiagLabel(std::vector<std::string>* operands,
                            std::string* output) {
-  // For fwd spmd rule, only those diagonal labels in output should not be
-  // sharded. But for bwd spmd rule, input and output are switched. So we simply
-  // set the spmd rule here to replace all diagonal labels as 1.
+  // Empirically, for fwd calculation, only those diagonal labels in output
+  // should not be sharded. e.g. iji->ii (diag), 'i' cannot be sharded;
+  // e.g. iji->i (trace), 'i' can be sharded.
+  // But during bwd calculation, input and output are switched.
+  // e.g. in the 'trace' case above when calculating x_grad, it will use
+  // i->ii, so 'i' cannot be sharded.
+  // Thus we simply set the spmd rule here to replace all diagonal labels as 1.
 
   // find diagonal labels
   std::unordered_map<char, int> char_count;
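The new comment encodes a subtle rule: for iji->ii (diag) the output label 'i' is diagonal, so sharding it would split coupled output dimensions; for iji->i (trace) the forward output is not diagonal, but its backward pass effectively runs i->ii, which reintroduces the diagonal output. The sketch below shows the detection half of the rule, counting repeated labels within a term the same way the char_count map above does. It is a standalone, illustrative sketch; DiagonalLabels is a made-up name, not Paddle API.

    // Illustrative only: a "diagonal" label is one repeated within a single
    // einsum term, mirroring the char_count idea in ConstraintOnDiagLabel.
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>

    std::unordered_set<char> DiagonalLabels(const std::string& term) {
      std::unordered_map<char, int> char_count;
      for (char c : term) {
        ++char_count[c];
      }
      std::unordered_set<char> diag;
      for (const auto& kv : char_count) {
        if (kv.second > 1) {
          diag.insert(kv.first);  // repeated within one term => diagonal
        }
      }
      return diag;
    }

    int main() {
      // "iji" repeats 'i', so 'i' is diagonal; 'j' is not.
      for (char c : DiagonalLabels("iji")) {
        std::cout << c << '\n';  // prints: i
      }
      return 0;
    }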
@@ -76,11 +80,9 @@ void ConstraintOnDiagLabel(std::vector<std::string>* operands,
   }
 }
 
-bool IsEinsumOuter(const std::string& equation) {
-  std::vector<std::string> inputs;
-  std::string output;
-  ParseEinsumEquation(equation, &inputs, &output);
-
+bool IsEinsumOuter(const std::vector<std::string>& inputs,
+                   const std::string& output) {
+  // Outer case: e.g. i, j -> ij; ij, kl -> ijkl
 
   if (inputs.size() != 2) {
     return false;
   }
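The new one-line comment pins down "outer" with two examples. As a sanity check, here is a minimal standalone predicate consistent with those examples: exactly two operands, and the output is the two operand label strings concatenated with no label repeated. This is an assumption about the behavior, since the full IsEinsumOuter body sits outside this hunk; LooksLikeOuter is a made-up name.

    // A plausible outer-product check, illustrative only.
    #include <cassert>
    #include <string>
    #include <unordered_set>
    #include <vector>

    bool LooksLikeOuter(const std::vector<std::string>& inputs,
                        const std::string& output) {
      if (inputs.size() != 2) {
        return false;  // outer takes exactly two operands
      }
      if (output != inputs[0] + inputs[1]) {
        return false;  // nothing may be contracted, reordered, or dropped
      }
      std::unordered_set<char> seen(output.begin(), output.end());
      return seen.size() == output.size();  // no repeated (diagonal) labels
    }

    int main() {
      assert(LooksLikeOuter({"i", "j"}, "ij"));      // i,j->ij is outer
      assert(LooksLikeOuter({"ij", "kl"}, "ijkl"));  // ij,kl->ijkl is outer
      assert(!LooksLikeOuter({"ij", "jk"}, "ik"));   // matmul contracts 'j'
      return 0;
    }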
@@ -105,6 +107,48 @@ bool IsEinsumOuter(const std::string& equation) {
   return true;
 }
 
+void ConstraintOnOuter(const phi::distributed::TensorDistAttr& x_attr,
+                       const phi::distributed::TensorDistAttr& y_attr,
+                       int x_ndim,
+                       int y_ndim,
+                       std::vector<int64_t>* x_dims_mapping,
+                       std::vector<int64_t>* y_dims_mapping) {
+  // For outer operation, only one operand and one dimension can be sharded
+  // todo: if multiple dimensions are requested to be sharded, decide which
+  // operand and which dimension to be sharded could be better
+
+  // we simply choose the first operand requested to be sharded and the
+  // first dimension requested to be sharded here
+  if (x_attr.is_shard()) {
+    bool meet_shard_axis = false;
+    for (int i = 0; i < x_ndim; ++i) {
+      if ((*x_dims_mapping)[i] != -1) {
+        meet_shard_axis = true;
+        continue;
+      }
+      if (meet_shard_axis) {
+        (*x_dims_mapping)[i] = -1;
+      }
+    }
+    // reset y_dims_mapping to all replicated
+    for (int i = 0; i < y_ndim; ++i) {
+      (*y_dims_mapping)[i] = -1;
+    }
+  } else if (y_attr.is_shard()) {
+    bool meet_shard_axis = false;
+    for (int i = 0; i < y_ndim; ++i) {
+      if ((*y_dims_mapping)[i] != -1) {
+        meet_shard_axis = true;
+        continue;
+      }
+      if (meet_shard_axis) {
+        (*y_dims_mapping)[i] = -1;
+      }
+    }
+    // no need to reset x_dims_mapping
+  }
+}
+
 SpmdInfo EinsumInferSpmd(const std::vector<DistMetaTensor>& inputs,
                          const std::string& equation) {
   PADDLE_ENFORCE_LE(
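To see what ConstraintOnOuter does to the mappings, consider i,j->ij on a 2-D mesh with x sharded on mesh axis 0 and y on mesh axis 1: x wins, and y is reset to fully replicated. The sketch below replays that with plain vectors standing in for TensorDistAttr; AnyShard approximates is_shard(), the first-sharded-axis loop is omitted for brevity, and all names are illustrative.

    // Illustrative stand-in for the constraint above: -1 in a dims_mapping
    // means "replicated along that tensor dimension".
    #include <cstdint>
    #include <iostream>
    #include <vector>

    bool AnyShard(const std::vector<int64_t>& dims_mapping) {
      for (int64_t d : dims_mapping) {
        if (d != -1) return true;
      }
      return false;
    }

    void ConstrainOuterSketch(std::vector<int64_t>* x,
                              std::vector<int64_t>* y) {
      if (AnyShard(*x)) {
        // x is sharded, so y must be reset to fully replicated.
        for (int64_t& d : *y) d = -1;
      }
      // else: only y (or neither) is sharded and x is already replicated,
      // so nothing needs to change.
    }

    int main() {
      std::vector<int64_t> x = {0};  // 'i' sharded on mesh dim 0
      std::vector<int64_t> y = {1};  // 'j' sharded on mesh dim 1
      ConstrainOuterSketch(&x, &y);
      std::cout << x[0] << " " << y[0] << '\n';  // prints: 0 -1
      return 0;
    }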
@@ -172,43 +216,14 @@ SpmdInfo EinsumInferSpmd(const std::vector<DistMetaTensor>& inputs,
   std::vector<int64_t> x_dims_mapping(x_dims_mapping_src);
   std::vector<int64_t> y_dims_mapping(y_dims_mapping_src);
 
-  // outer case
-  // for outer operation, only one operand and one dimension can be sharded
-  // todo: if multiple dimensions are requested to be sharded, decide which
-  // operand and which dimension to be sharded could be better
-  if (IsEinsumOuter(equation)) {
-    // we simply choose the first operand requested to be sharded and the
-    // first dimension requested to be sharded here
-    if (x_dist_attr_src.is_shard()) {
-      bool meet_shard_axis = false;
-      for (int i = 0; i < x_ndim; ++i) {
-        if (x_dims_mapping[i] != -1) {
-          meet_shard_axis = true;
-          continue;
-        }
-        if (meet_shard_axis) {
-          x_dims_mapping[i] = -1;
-        }
-      }
-      // reset y_dims_mapping to all replicated
-      for (int i = 0; i < y_ndim; ++i) {
-        y_dims_mapping[i] = -1;
-      }
-    } else if (y_dist_attr_src.is_shard()) {
-      bool meet_shard_axis = false;
-      for (int i = 0; i < y_ndim; ++i) {
-        if (y_dims_mapping[i] != -1) {
-          meet_shard_axis = true;
-          continue;
-        }
-        if (meet_shard_axis) {
-          y_dims_mapping[i] = -1;
-        }
-      }
-      // no need to reset x_dims_mapping
-    }
+  if (IsEinsumOuter(operands, right)) {
+    ConstraintOnOuter(x_dist_attr_src,
+                      y_dist_attr_src,
+                      x_ndim,
+                      y_ndim,
+                      &x_dims_mapping,
+                      &y_dims_mapping);
   }
-
   VLOG(6) << "EinsumInferSpmd InferForward Inputs: "
           << "X shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
           << str_join(x_dims_mapping) << "], Y shape: [" << str_join(y_shape)
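Note the signature change that enables this refactor: IsEinsumOuter previously re-parsed the equation itself, while the new overload takes the operand and output label strings the caller already holds (the operands and right variables here, presumably filled by an earlier ParseEinsumEquation call in the function), so the equation is parsed once instead of twice. The grad hunk below applies the identical replacement.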
@@ -312,40 +327,13 @@ SpmdInfo EinsumGradInferSpmd(const std::vector<DistMetaTensor>& inputs,
   std::vector<int64_t> y_dims_mapping(y_dims_mapping_src);
   std::vector<int64_t> out_grad_dims_mapping(out_grad_dims_mapping_src);
 
-  // outer case
-  // for outer operation, only one operand and one dimension can be sharded
-  // todo: if multiple dimensions are requested to be sharded, decide which
-  // operand and which dimension to be sharded could be better
-  if (IsEinsumOuter(equation)) {
-    // inputs
-    if (x_dist_attr_src.is_shard()) {
-      bool meet_shard_axis = false;
-      for (int i = 0; i < x_ndim; ++i) {
-        if (x_dims_mapping[i] != -1) {
-          meet_shard_axis = true;
-          continue;
-        }
-        if (meet_shard_axis) {
-          x_dims_mapping[i] = -1;
-        }
-      }
-      // reset y_dims_mapping to all replicated
-      for (int i = 0; i < y_ndim; ++i) {
-        y_dims_mapping[i] = -1;
-      }
-    } else if (y_dist_attr_src.is_shard()) {
-      bool meet_shard_axis = false;
-      for (int i = 0; i < y_ndim; ++i) {
-        if (y_dims_mapping[i] != -1) {
-          meet_shard_axis = true;
-          continue;
-        }
-        if (meet_shard_axis) {
-          y_dims_mapping[i] = -1;
-        }
-      }
-      // no need to reset x_dims_mapping
-    }
+  if (IsEinsumOuter(operands, right)) {
+    ConstraintOnOuter(x_dist_attr_src,
+                      y_dist_attr_src,
+                      x_ndim,
+                      y_ndim,
+                      &x_dims_mapping,
+                      &y_dims_mapping);
   }
   // out_grad, x, y
   std::unordered_map<std::string, int64_t> fwd_axis_to_dim_map =
