Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.52】 为 Paddle 新增 squeeze 和 unsqueeze 的 spmd 切分推导规则 #57877

Merged
merged 26 commits into from
Nov 27, 2023
Merged
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e326aee
Add spmd segmentation and derivation rules for squeeze to Paddle
WintersMontagne10335 Oct 4, 2023
148925a
Add spmd segmentation derivation rule for unsqueeze to Paddle
WintersMontagne10335 Oct 4, 2023
fd3f1db
fix bugs
WintersMontagne10335 Oct 4, 2023
4012a49
fix bugs
WintersMontagne10335 Oct 5, 2023
4aff5eb
fix bugs
WintersMontagne10335 Oct 5, 2023
9e9140f
fix bugs
WintersMontagne10335 Oct 5, 2023
6c2f23f
Add unit test code
WintersMontagne10335 Oct 6, 2023
b727efe
modify squeeze.cc and CMakeLists.txt
WintersMontagne10335 Oct 13, 2023
1efc51d
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 13, 2023
98af5a4
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 13, 2023
34b7024
write separate rules
WintersMontagne10335 Oct 14, 2023
32e0ed8
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 14, 2023
80c07b2
fix bugs
WintersMontagne10335 Oct 14, 2023
f13fcdd
fix bugs
WintersMontagne10335 Oct 15, 2023
8acaa6d
fix bugs
WintersMontagne10335 Oct 15, 2023
504006f
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 20, 2023
97452fb
remove unsqueeze spmd rule
WintersMontagne10335 Oct 21, 2023
c129162
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 21, 2023
63983b9
modified: test/auto_parallel/spmd_rules/test_squeeze_rule.py
WintersMontagne10335 Oct 21, 2023
1fbb2cf
re-run CI
WintersMontagne10335 Oct 21, 2023
ea1e6fc
fix bugs
WintersMontagne10335 Oct 28, 2023
9f95328
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 28, 2023
f57ee3a
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Nov 24, 2023
5301bdc
modify pointer to smart pointer
WintersMontagne10335 Nov 24, 2023
6114774
fix bugs
WintersMontagne10335 Nov 24, 2023
dd3d96c
fix bugs
WintersMontagne10335 Nov 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix bugs
  • Loading branch information
WintersMontagne10335 committed Oct 15, 2023
commit 8acaa6dd7834a70a9cd0697c06415b41f029b4ab
63 changes: 24 additions & 39 deletions paddle/phi/infermeta/spmd_rules/squeeze.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,74 +29,59 @@ namespace distributed {

using phi::distributed::auto_parallel::str_join;

std::vector<DimTrans*> MakeSqueezeDimTransWithoutAxis(
const std::vector<int64_t>& x_shape, std::vector<int64_t>* out_shape) {
std::vector<DimTrans*> ret;

void MakeSqueezeDimTransWithoutAxis(const std::vector<int64_t>& x_shape,
std::vector<int64_t>* out_shape,
std::vector<DimTrans*>* trans) {
for (int64_t i = 0, n = static_cast<int64_t>(x_shape.size()); i < n; i++) {
if (x_shape[i] != 1) {
ret.emplace_back(new InputDim(i));
trans->emplace_back(new InputDim(i));
out_shape->emplace_back(x_shape[i]);
}
}

return ret;
}

std::vector<DimTrans*> MakeSqueezeDimTransWithAxis(
const std::vector<int64_t>& x_shape,
std::vector<int64_t>* out_shape,
const std::vector<int64_t>& axis) {
std::vector<DimTrans*> ret;

void MakeSqueezeDimTransWithAxis(const std::vector<int64_t>& x_shape,
std::vector<int64_t>* out_shape,
const std::vector<int64_t>& axis,
std::vector<DimTrans*>* trans) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该可以写得更简单些,先算出输出的size,然后遍历输出的维度,如果某一维在 axis 里且 shape 是 1就用Singleton,否则用 InputDim。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

for (int64_t i = 0, n = static_cast<int64_t>(x_shape.size()); i < n; i++) {
ret.emplace_back(new InputDim(i));
trans->emplace_back(new InputDim(i));
out_shape->emplace_back(x_shape[i]);
}

for (int64_t i = 0, n = static_cast<int64_t>(axis.size()); i < n; i++) {
if (x_shape[axis[i]] == 1) {
ret.erase(ret.begin() + axis[i]);
trans->erase(trans->begin() + axis[i]);
out_shape->erase(out_shape->begin() + axis[i]);
}
}

return ret;
}

std::vector<DimTrans*> MakeSqueezeDimTransReverseWithoutAxis(
const std::vector<int64_t>& x_shape) {
std::vector<DimTrans*> ret;

void MakeSqueezeDimTransReverseWithoutAxis(const std::vector<int64_t>& x_shape,
std::vector<DimTrans*>* trans) {
for (int64_t i = 0, j = 0, n = static_cast<int64_t>(x_shape.size()); i < n;
i++) {
if (x_shape[i] != 1) {
ret.emplace_back(new InputDim(j++));
trans->emplace_back(new InputDim(j++));
} else {
ret.emplace_back(new Singleton());
trans->emplace_back(new Singleton());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成 make_shared

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
}

return ret;
}

std::vector<DimTrans*> MakeSqueezeDimTransReverseWithAxis(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& axis) {
std::vector<DimTrans*> ret;

void MakeSqueezeDimTransReverseWithAxis(const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& axis,
std::vector<DimTrans*>* trans) {
for (int64_t i = 0, n = static_cast<int64_t>(out_shape.size()); i < n; i++) {
ret.emplace_back(new InputDim(i));
trans->emplace_back(new InputDim(i));
}

for (int64_t i = 0, n = static_cast<int64_t>(axis.size()); i < n; i++) {
if (x_shape[axis[i]] == 1) {
ret.emplace(ret.begin() + axis[i], new Singleton());
trans->emplace(trans->begin() + axis[i], new Singleton());
}
}

return ret;
}

bool SqueezeCompare(const int64_t& a, const int64_t& b) { return a > b; }
Expand Down Expand Up @@ -125,7 +110,7 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
std::vector<int64_t> out_shape;

if (static_cast<int64_t>(axis.size()) == 0) {
trans = MakeSqueezeDimTransWithoutAxis(x_shape, &out_shape);
MakeSqueezeDimTransWithoutAxis(x_shape, &out_shape, &trans);
} else {
std::vector<int64_t> axis_copy(axis);
for (int64_t i = 0, n = static_cast<int64_t>(axis_copy.size()); i < n;
Expand All @@ -135,7 +120,7 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
}
}
std::sort(axis_copy.begin(), axis_copy.end(), SqueezeCompare);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉自己定义的比较函数

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

trans = MakeSqueezeDimTransWithAxis(x_shape, &out_shape, axis_copy);
MakeSqueezeDimTransWithAxis(x_shape, &out_shape, axis_copy, &trans);
}

// Step2: Infer the dims mapping of input (if reshard is
Expand Down Expand Up @@ -194,7 +179,7 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
std::vector<DimTrans*> trans;

if (static_cast<int64_t>(axis.size()) == 0) {
trans = MakeSqueezeDimTransReverseWithoutAxis(x_shape);
MakeSqueezeDimTransReverseWithoutAxis(x_shape, &trans);
} else {
std::vector<int64_t> axis_copy(axis);
for (int64_t i = 0, n = static_cast<int64_t>(axis_copy.size()); i < n;
Expand All @@ -204,7 +189,7 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
}
}
std::sort(axis_copy.begin(), axis_copy.end(), SqueezeReverseCompare);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用标准库自带的比较函数

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

trans = MakeSqueezeDimTransReverseWithAxis(x_shape, out_shape, axis_copy);
MakeSqueezeDimTransReverseWithAxis(x_shape, out_shape, axis_copy, &trans);
}

// Step2: Infer the dims mapping of input with
Expand Down