Skip to content

Conversation

@pkuzyc
Copy link
Contributor

@pkuzyc pkuzyc commented Jul 5, 2023

PR types

New features

PR changes

Others

Description

Pcard-70448
Add reshape spmd rule for auto parallel. This rule infers the output's distributed attribute with the following two steps:

  1. Compute the transformation from the original shape to the target shape.
  2. Compute the output's distributed attribute according to the transformation from step 1.

@paddle-bot
Copy link

paddle-bot bot commented Jul 5, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@pkuzyc pkuzyc closed this Jul 5, 2023
@pkuzyc pkuzyc reopened this Jul 5, 2023
@pkuzyc pkuzyc force-pushed the reshape_rule branch 2 times, most recently from a5efe9d to 4370e87 Compare July 12, 2023 09:21
@paddle-ci-bot
Copy link

paddle-ci-bot bot commented Jul 20, 2023

Sorry to inform you that 4370e87's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@paddle-ci-bot
Copy link

paddle-ci-bot bot commented Aug 1, 2023

Sorry to inform you that 7b96d26's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

src_size --> src_numel / src_nelem,
src_size is ambiguous with src_shape

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, use total_elem_num_src now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is redundant to compute the output shape again here.
only DimTrans::Type::SPLIT need maintain the output shape segment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, remove the computing output parts now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be better local_split_shape --> local_axis_size ?
shape: [a,b,c]
axis_size: the value of a or b or c

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, rename to "local_splitted_shape_value".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to trick to calculate the input_dims_mapping_dst.
not need to introduce "reshard" into InferSPMD;
directly use shardable vector to remove "sharded" in input_dims_mapping_src.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, remove the "reshard" word.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

too trick to calculate the output_dims_mapping.
should main dim_map_tgt2src, and unshardedable map for output axis.

Copy link
Contributor Author

@pkuzyc pkuzyc Aug 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, remove the redundant code. Using "dim_map_tgt2src" will meet some bugs when input_dims_mapping should be set to replicated, so keep "dim_map_src2tgt" now, and it is more intuitive.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: think about print useful info about tensor axes and dims_mapping for debug:

idea1: construct einsum notation for debug and giving corresponding axes between input and output a specific character, therefore user could be notified that those axes are related.

idea2: print out the DimTrans and make the info readable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, print the transformation info.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unitest should include following cases:

  1. input axis directly map to output axis
  2. multiple input axes merge into single output axis, with shard on first input axis
  3. multiple input axes merge into single output axis, with shard on axis other than first input axis
  4. single input axis split into multiple output axes, with first output axis dividable
  5. single input axis split into multiple output axes, with first output axis non-dividable
  6. multiple input axes transform into multiple output axis, with shard on first input axis/shard on input axis other than the first axis/ first output axis dividable/ first output axis non-dividable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use static global vector?

Copy link
Contributor Author

@pkuzyc pkuzyc Aug 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here "static" indicates that the global variable can only be used in this file. The global vector is used to store all transformation objects so that we can free them after inferring distributed attributes.

@JZ-LIANG JZ-LIANG merged commit a97b507 into PaddlePaddle:develop Aug 14, 2023
@pkuzyc pkuzyc deleted the reshape_rule branch February 6, 2024 02:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants