@@ -35,9 +35,13 @@ void ParseEinsumEquation(const std::string& equation,
 
 void ConstraintOnDiagLabel(std::vector<std::string>* operands,
                            std::string* output) {
-  // For fwd spmd rule, only those diagonal labels in output should not be
-  // sharded. But for bwd spmd rule, input and output are switched. So we simply
-  // set the spmd rule here to replace all diagonal labels as 1.
+  // Empirically, for the fwd calculation only the diagonal labels that appear
+  // in the output must not be sharded: e.g. for iji->ii (diag), 'i' cannot be
+  // sharded, while for iji->i (trace), 'i' can be sharded.
+  // During the bwd calculation, however, input and output are switched: in
+  // the 'trace' case above, computing x_grad uses i->ii, so there 'i' cannot
+  // be sharded either.
+  // Thus the spmd rule here simply marks every diagonal label as replicated.
 
   // find diagonal labels
   std::unordered_map<char, int> char_count;
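
To make the diagonal-label rule concrete, here is a minimal standalone sketch of detecting diagonal labels by counting repeated characters per operand, the same idea the `char_count` map above begins. The helper name and the `dims_mapping` parameter are illustrative, not part of this PR:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// A label is "diagonal" when it occurs more than once within one operand,
// e.g. both 'i's in "iji". Per the comment above, every axis carrying such
// a label falls back to replicated (dims mapping -1).
void ReplicateDiagLabels(const std::string& operand_labels,
                         std::vector<int64_t>* dims_mapping) {
  std::unordered_map<char, int> char_count;
  for (char c : operand_labels) ++char_count[c];
  for (size_t i = 0; i < operand_labels.size(); ++i) {
    if (char_count[operand_labels[i]] > 1) {
      (*dims_mapping)[i] = -1;  // replicate the diagonal axis
    }
  }
}
```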
@@ -76,11 +80,9 @@ void ConstraintOnDiagLabel(std::vector<std::string>* operands,
   }
 }
 
-bool IsEinsumOuter(const std::string& equation) {
-  std::vector<std::string> inputs;
-  std::string output;
-  ParseEinsumEquation(equation, &inputs, &output);
-
+bool IsEinsumOuter(const std::vector<std::string>& inputs,
+                   const std::string& output) {
+  // Outer case: e.g. i, j -> ij; ij, kl -> ijkl
   if (inputs.size() != 2) {
     return false;
   }
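
For reference, a hedged sketch of what the complete outer-product test might look like, assuming the portion of `IsEinsumOuter` outside this hunk verifies that the output is exactly the concatenation of the two inputs' labels with no label repeated; the helper name is illustrative:

```cpp
#include <string>
#include <vector>

// Sketch: an equation is an outer product when the output is the two input
// label strings concatenated and no label appears twice anywhere.
bool IsOuterSketch(const std::vector<std::string>& inputs,
                   const std::string& output) {
  if (inputs.size() != 2) return false;
  const std::string joined = inputs[0] + inputs[1];
  if (joined != output) return false;  // e.g. "i" + "j" must equal "ij"
  for (size_t i = 0; i < joined.size(); ++i) {
    for (size_t j = i + 1; j < joined.size(); ++j) {
      if (joined[i] == joined[j]) return false;  // repeated label: not outer
    }
  }
  return true;
}
// IsOuterSketch({"ij", "kl"}, "ijkl") -> true  (outer)
// IsOuterSketch({"ij", "jk"}, "ik")   -> false (contraction, not outer)
```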
@@ -105,6 +107,48 @@ bool IsEinsumOuter(const std::string& equation) {
   return true;
 }
 
+void ConstraintOnOuter(const phi::distributed::TensorDistAttr& x_attr,
+                       const phi::distributed::TensorDistAttr& y_attr,
+                       int x_ndim,
+                       int y_ndim,
+                       std::vector<int64_t>* x_dims_mapping,
+                       std::vector<int64_t>* y_dims_mapping) {
+  // For an outer operation, only one operand and one dimension may be sharded.
+  // TODO: if multiple dimensions are requested to be sharded, deciding which
+  // operand and which dimension to shard could be done better.
+
+  // We simply keep the first operand requested to be sharded and its first
+  // sharded dimension here.
+  if (x_attr.is_shard()) {
+    bool meet_shard_axis = false;
+    for (int i = 0; i < x_ndim; ++i) {
+      if ((*x_dims_mapping)[i] != -1) {
+        if (meet_shard_axis) {
+          (*x_dims_mapping)[i] = -1;  // later sharded axes become replicated
+        } else {
+          meet_shard_axis = true;  // the first sharded axis is kept
+        }
+      }
+    }
+    // reset y_dims_mapping to all replicated
+    for (int i = 0; i < y_ndim; ++i) {
+      (*y_dims_mapping)[i] = -1;
+    }
+  } else if (y_attr.is_shard()) {
+    bool meet_shard_axis = false;
+    for (int i = 0; i < y_ndim; ++i) {
+      if ((*y_dims_mapping)[i] != -1) {
+        if (meet_shard_axis) {
+          (*y_dims_mapping)[i] = -1;  // later sharded axes become replicated
+        } else {
+          meet_shard_axis = true;  // the first sharded axis is kept
+        }
+      }
+    }
+    // no need to reset x_dims_mapping
+  }
+}
+
 SpmdInfo EinsumInferSpmd(const std::vector<DistMetaTensor>& inputs,
                          const std::string& equation) {
   PADDLE_ENFORCE_LE(
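
The core of `ConstraintOnOuter` is the "keep only the first sharded axis" loop; below is a self-contained sketch of that loop on plain vectors (omitting `TensorDistAttr`), with an example of the intended effect:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Keep the first sharded axis; every later sharded axis is replicated.
void KeepFirstShardAxis(std::vector<int64_t>* dims_mapping) {
  bool meet_shard_axis = false;
  for (auto& dim : *dims_mapping) {
    if (dim != -1) {
      if (meet_shard_axis) {
        dim = -1;  // a later sharded axis falls back to replicated
      } else {
        meet_shard_axis = true;  // the first sharded axis survives
      }
    }
  }
}

int main() {
  std::vector<int64_t> x_dims_mapping = {-1, 0, 1};
  KeepFirstShardAxis(&x_dims_mapping);
  for (int64_t d : x_dims_mapping) std::cout << d << ' ';  // prints: -1 0 -1
  std::cout << '\n';
}
```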
@@ -172,43 +216,14 @@ SpmdInfo EinsumInferSpmd(const std::vector<DistMetaTensor>& inputs,
   std::vector<int64_t> x_dims_mapping(x_dims_mapping_src);
   std::vector<int64_t> y_dims_mapping(y_dims_mapping_src);
 
-  // outer case
-  // for outer operation, only one operand and one dimension can be sharded
-  // todo: if multiple dimensions are requested to be sharded, decide which
-  // operand and which dimension to be sharded could be better
-  if (IsEinsumOuter(equation)) {
-    // we simply choose the first operand requested to be sharded and the
-    // first dimension requested to be sharded here
-    if (x_dist_attr_src.is_shard()) {
-      bool meet_shard_axis = false;
-      for (int i = 0; i < x_ndim; ++i) {
-        if (x_dims_mapping[i] != -1) {
-          meet_shard_axis = true;
-          continue;
-        }
-        if (meet_shard_axis) {
-          x_dims_mapping[i] = -1;
-        }
-      }
-      // reset y_dims_mapping to all replicated
-      for (int i = 0; i < y_ndim; ++i) {
-        y_dims_mapping[i] = -1;
-      }
-    } else if (y_dist_attr_src.is_shard()) {
-      bool meet_shard_axis = false;
-      for (int i = 0; i < y_ndim; ++i) {
-        if (y_dims_mapping[i] != -1) {
-          meet_shard_axis = true;
-          continue;
-        }
-        if (meet_shard_axis) {
-          y_dims_mapping[i] = -1;
-        }
-      }
-      // no need to reset x_dims_mapping
-    }
+  if (IsEinsumOuter(operands, right)) {
+    ConstraintOnOuter(x_dist_attr_src,
+                      y_dist_attr_src,
+                      x_ndim,
+                      y_ndim,
+                      &x_dims_mapping,
+                      &y_dims_mapping);
   }
-
   VLOG(6) << "EinsumInferSpmd InferForward Inputs: "
           << "X shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
           << str_join(x_dims_mapping) << "], Y shape: [" << str_join(y_shape)
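
The call sites now pass `operands` and `right`, which are presumably produced earlier in `EinsumInferSpmd` by `ParseEinsumEquation`; a simplified illustration of that split follows (the parsing below is a sketch, not Paddle's actual implementation):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Simplified split of an einsum equation into operand label strings and the
// output label string, assuming a "lhs->rhs" form with ','-separated operands
// and no broadcast ('...') handling.
void SplitEquation(const std::string& equation,
                   std::vector<std::string>* operands,
                   std::string* right) {
  const size_t arrow = equation.find("->");
  const std::string lhs = equation.substr(0, arrow);
  *right = equation.substr(arrow + 2);
  size_t start = 0, comma = 0;
  while ((comma = lhs.find(',', start)) != std::string::npos) {
    operands->push_back(lhs.substr(start, comma - start));
    start = comma + 1;
  }
  operands->push_back(lhs.substr(start));
}

int main() {
  std::vector<std::string> operands;
  std::string right;
  SplitEquation("ij,kl->ijkl", &operands, &right);
  // operands = {"ij", "kl"}, right = "ijkl": the outer case
  std::cout << operands[0] << ' ' << operands[1] << " -> " << right << '\n';
}
```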
@@ -312,40 +327,13 @@ SpmdInfo EinsumGradInferSpmd(const std::vector<DistMetaTensor>& inputs,
   std::vector<int64_t> y_dims_mapping(y_dims_mapping_src);
   std::vector<int64_t> out_grad_dims_mapping(out_grad_dims_mapping_src);
 
-  // outer case
-  // for outer operation, only one operand and one dimension can be sharded
-  // todo: if multiple dimensions are requested to be sharded, decide which
-  // operand and which dimension to be sharded could be better
-  if (IsEinsumOuter(equation)) {
-    // inputs
-    if (x_dist_attr_src.is_shard()) {
-      bool meet_shard_axis = false;
-      for (int i = 0; i < x_ndim; ++i) {
-        if (x_dims_mapping[i] != -1) {
-          meet_shard_axis = true;
-          continue;
-        }
-        if (meet_shard_axis) {
-          x_dims_mapping[i] = -1;
-        }
-      }
-      // reset y_dims_mapping to all replicated
-      for (int i = 0; i < y_ndim; ++i) {
-        y_dims_mapping[i] = -1;
-      }
-    } else if (y_dist_attr_src.is_shard()) {
-      bool meet_shard_axis = false;
-      for (int i = 0; i < y_ndim; ++i) {
-        if (y_dims_mapping[i] != -1) {
-          meet_shard_axis = true;
-          continue;
-        }
-        if (meet_shard_axis) {
-          y_dims_mapping[i] = -1;
-        }
-      }
-      // no need to reset x_dims_mapping
-    }
+  if (IsEinsumOuter(operands, right)) {
+    ConstraintOnOuter(x_dist_attr_src,
+                      y_dist_attr_src,
+                      x_ndim,
+                      y_ndim,
+                      &x_dims_mapping,
+                      &y_dims_mapping);
   }
   // out_grad, x, y
   std::unordered_map<std::string, int64_t> fwd_axis_to_dim_map =
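
`fwd_axis_to_dim_map` presumably associates each einsum label with the mesh dimension its tensor axis is sharded along (-1 meaning replicated). A hypothetical sketch of building such a map; the name and construction below are assumptions for illustration, not the PR's code:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical: map each label character (as a string key) to the mesh dim
// of the corresponding tensor axis, -1 meaning replicated.
std::unordered_map<std::string, int64_t> BuildAxisToDimMap(
    const std::string& axes, const std::vector<int64_t>& dims_mapping) {
  std::unordered_map<std::string, int64_t> axis_to_dim;
  for (size_t i = 0; i < axes.size(); ++i) {
    axis_to_dim[std::string(1, axes[i])] = dims_mapping[i];
  }
  return axis_to_dim;
}
// e.g. axes "ij" with dims_mapping {0, -1} yields {"i": 0, "j": -1}, which
// the grad rule can consult when inferring x_grad/y_grad placements.
```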