[Semi-Auto] Add reshape spmd rule #55177
Your PR was submitted successfully. Thank you for contributing to this open-source project!
a5efe9d to 4370e87
Sorry to inform you that 4370e87's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Sorry to inform you that 7b96d26's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Rename src_size to src_numel or src_nelem; src_size is easily confused with src_shape.
Done, it is now named total_elem_num_src.
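To illustrate the naming point, here is a minimal Python sketch of the difference between a shape (one size per axis) and a total element count (a single scalar). The helper `resolve_target_shape` and its use of the count to fill in a `-1` entry are assumptions made for illustration, not code from this PR.

```python
from math import prod


def resolve_target_shape(src_shape, tgt_shape):
    """Hypothetical helper: replace a single -1 in tgt_shape using the element count."""
    # A shape is a list of per-axis sizes; the element count is one scalar.
    total_elem_num_src = prod(src_shape)
    assert tgt_shape.count(-1) <= 1, "at most one axis may be inferred"
    # Product of the axes that are already known in the target shape.
    known = prod(d for d in tgt_shape if d != -1)
    return [total_elem_num_src // known if d == -1 else d for d in tgt_shape]


src_shape = [6, 12, 48]
total_elem_num_src = prod(src_shape)  # scalar, not a list
resolved = resolve_target_shape(src_shape, [72, -1])
```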
It is redundant to compute the output shape again here.
Only DimTrans::Type::SPLIT needs to maintain the output shape segment.
Done, the output shape computation has been removed.
Would local_axis_size be a better name than local_split_shape?
shape: [a, b, c]
axis_size: the value of a, b, or c
Done, renamed to "local_splitted_shape_value".
This is too tricky a way to calculate input_dims_mapping_dst.
There is no need to introduce "reshard" into InferSPMD;
directly use the shardable vector to remove "sharded" in input_dims_mapping_src.
Done, the word "reshard" has been removed.
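The suggestion above, using a shardable vector to strip sharding from the source dims mapping, might look roughly like the following Python sketch. The function name `apply_shardable_mask` is hypothetical, and the sketch assumes the usual convention that `-1` in a dims mapping means "replicated"; the PR's actual C++ code may differ.

```python
def apply_shardable_mask(dims_mapping_src, shardable):
    """Hypothetical helper: keep a mesh dim only where the axis is shardable."""
    assert len(dims_mapping_src) == len(shardable)
    # Assumption: -1 marks an axis as replicated (not sharded).
    return [mesh_dim if ok else -1
            for mesh_dim, ok in zip(dims_mapping_src, shardable)]


# Axis 1 is sharded on mesh dim 1 but is not shardable for this reshape,
# so its sharding is dropped; axis 0 keeps mesh dim 0.
dims_mapping_dst = apply_shardable_mask([0, 1, -1], [True, False, True])
```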
This is too tricky a way to calculate output_dims_mapping.
It should maintain dim_map_tgt2src and an unshardable map for the output axes.
Done, the redundant code has been removed. Using "dim_map_tgt2src" runs into bugs when input_dims_mapping should be set to replicated, so "dim_map_src2tgt" is kept for now; it is also more intuitive.
TODO: consider printing useful info about tensor axes and dims_mapping for debugging:
idea1: construct an einsum notation for debugging, giving corresponding input and output axes the same character, so the user can see that those axes are related.
idea2: print out the DimTrans in a readable form.
Done, the transformation info is now printed.
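As a rough illustration of idea2, the Python sketch below gives each transformation node a readable string form and prints one line per output axis. Only the name DimTrans and its SPLIT type appear in this thread; the classes `InputDim`, `Flatten`, and `Split`, their fields, and the `describe` helper are hypothetical stand-ins, not the PR's actual classes or output format.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class InputDim:
    """Hypothetical node: output axis taken directly from one input axis."""
    index: int

    def __str__(self):
        return f"InputDim({self.index})"


@dataclass
class Flatten:
    """Hypothetical node: several input axes merged into one."""
    inputs: List[object]

    def __str__(self):
        return "Flatten(" + ", ".join(str(i) for i in self.inputs) + ")"


@dataclass
class Split:
    """Hypothetical node: one piece of an axis that was split into several."""
    source: object
    split_shape: List[int]
    split_id: int

    def __str__(self):
        return f"Split({self.source}, shape={self.split_shape}, id={self.split_id})"


def describe(trans):
    """Return one readable line per output axis."""
    return "\n".join(f"out[{i}] <- {t}" for i, t in enumerate(trans))


# Reshape [6, 12, 48] -> [72, 48]: axes 0 and 1 are merged, axis 2 is kept.
info = describe([Flatten([InputDim(0), InputDim(1)]), InputDim(2)])
print(info)
```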
The unit test should include the following cases:
- input axis directly maps to output axis
- multiple input axes merge into a single output axis, with shard on the first input axis
- multiple input axes merge into a single output axis, with shard on an axis other than the first input axis
- single input axis splits into multiple output axes, with the first output axis divisible
- single input axis splits into multiple output axes, with the first output axis non-divisible
- multiple input axes transform into multiple output axes, with shard on the first input axis / shard on an input axis other than the first / first output axis divisible / first output axis non-divisible
Done
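The requested cases can be exercised against a simplified model of the rule. The Python sketch below is an assumption about the behavior, not the PR's C++ implementation: it groups input and output axes whose element counts match, and within each group keeps a sharding only when it sits on the first input axis and the first output axis is divisible by the mesh dimension size. The names `group_axes` and `infer_reshape_dims_mapping` are hypothetical, `-1` is assumed to mean "replicated", and the input mapping is assumed to be valid for the input shape.

```python
from math import prod


def group_axes(src_shape, tgt_shape):
    """Pair runs of input and output axes that cover the same number of elements."""
    groups, i, j = [], 0, 0
    while i < len(src_shape) and j < len(tgt_shape):
        src_axes, tgt_axes = [i], [j]
        src_n, tgt_n = src_shape[i], tgt_shape[j]
        i, j = i + 1, j + 1
        while src_n != tgt_n:  # grow the smaller side until the products agree
            if src_n < tgt_n:
                src_n *= src_shape[i]
                src_axes.append(i)
                i += 1
            else:
                tgt_n *= tgt_shape[j]
                tgt_axes.append(j)
                j += 1
        groups.append((src_axes, tgt_axes))
    if groups:  # leftover axes can only have size 1; attach them to the last group
        groups[-1][0].extend(range(i, len(src_shape)))
        groups[-1][1].extend(range(j, len(tgt_shape)))
    return groups


def infer_reshape_dims_mapping(src_shape, tgt_shape, dims_mapping_src, mesh_shape):
    """Return (input dims mapping, output dims mapping) under the assumed rule."""
    assert prod(src_shape) == prod(tgt_shape), "reshape must preserve element count"
    dims_mapping_in = [-1] * len(src_shape)
    dims_mapping_out = [-1] * len(tgt_shape)
    for src_axes, tgt_axes in group_axes(src_shape, tgt_shape):
        first_in, first_out = src_axes[0], tgt_axes[0]
        mesh_dim = dims_mapping_src[first_in]
        if mesh_dim == -1:
            continue  # sharding on non-first input axes of a group is dropped
        if tgt_shape[first_out] % mesh_shape[mesh_dim] == 0:
            dims_mapping_in[first_in] = mesh_dim
            dims_mapping_out[first_out] = mesh_dim
    return dims_mapping_in, dims_mapping_out


mesh = [2, 3]
direct = infer_reshape_dims_mapping([6, 12], [6, 12], [0, -1], mesh)
merge_first = infer_reshape_dims_mapping([6, 12, 48], [72, 48], [0, -1, 1], mesh)
merge_other = infer_reshape_dims_mapping([6, 12, 48], [72, 48], [-1, 0, 1], mesh)
split_ok = infer_reshape_dims_mapping([72, 48], [6, 12, 48], [0, -1], mesh)
split_bad = infer_reshape_dims_mapping([72, 48], [3, 24, 48], [0, -1], mesh)
```

In this model, a sharding that cannot be carried over is reset to replicated on the input side as well, which corresponds to the shardable-vector idea discussed earlier in the review.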
Why use a static global vector?
Here "static" means that the global variable can only be used in this file. The global vector stores all transformation objects so that we can free them after inferring the distributed attributes.
PR types: New features
PR changes: Others
Description
Pcard-70448
Add reshape spmd rule for auto parallel. This rule infers the output's distributed attribute with the following two steps: