Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.52】 为 Paddle 新增 squeeze 和 unsqueeze 的 spmd 切分推导规则 #57877

Merged
merged 26 commits into from
Nov 27, 2023
Merged
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e326aee
Add spmd segmentation and derivation rules for squeeze to Paddle
WintersMontagne10335 Oct 4, 2023
148925a
Add spmd segmentation derivation rule for unsqueeze to Paddle
WintersMontagne10335 Oct 4, 2023
fd3f1db
fix bugs
WintersMontagne10335 Oct 4, 2023
4012a49
fix bugs
WintersMontagne10335 Oct 5, 2023
4aff5eb
fix bugs
WintersMontagne10335 Oct 5, 2023
9e9140f
fix bugs
WintersMontagne10335 Oct 5, 2023
6c2f23f
Add unit test code
WintersMontagne10335 Oct 6, 2023
b727efe
modify squeeze.cc and CMakeLists.txt
WintersMontagne10335 Oct 13, 2023
1efc51d
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 13, 2023
98af5a4
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 13, 2023
34b7024
write separate rules
WintersMontagne10335 Oct 14, 2023
32e0ed8
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 14, 2023
80c07b2
fix bugs
WintersMontagne10335 Oct 14, 2023
f13fcdd
fix bugs
WintersMontagne10335 Oct 15, 2023
8acaa6d
fix bugs
WintersMontagne10335 Oct 15, 2023
504006f
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 20, 2023
97452fb
remove unsqueeze spmd rule
WintersMontagne10335 Oct 21, 2023
c129162
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 21, 2023
63983b9
modified: test/auto_parallel/spmd_rules/test_squeeze_rule.py
WintersMontagne10335 Oct 21, 2023
1fbb2cf
re-run CI
WintersMontagne10335 Oct 21, 2023
ea1e6fc
fix bugs
WintersMontagne10335 Oct 28, 2023
9f95328
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Oct 28, 2023
f57ee3a
Merge remote-tracking branch 'upstream/develop' into winters009
WintersMontagne10335 Nov 24, 2023
5301bdc
modify pointer to smart pointer
WintersMontagne10335 Nov 24, 2023
6114774
fix bugs
WintersMontagne10335 Nov 24, 2023
dd3d96c
fix bugs
WintersMontagne10335 Nov 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix bugs
  • Loading branch information
WintersMontagne10335 committed Nov 24, 2023
commit dd3d96cfd3218d6f4d2cec8b1563897c643be5c6
12 changes: 5 additions & 7 deletions paddle/phi/infermeta/spmd_rules/squeeze.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void MakeSqueezeDimTransWithAxis(
if (x_shape[i] == 1) {
auto it = find(axis.begin(), axis.end(), i);
if (it == axis.end()) {
trans->emplace_back(new Singleton());
trans->emplace_back(std::make_shared<Singleton>());
out_shape->emplace_back(1);
}
} else {
Expand All @@ -68,7 +68,7 @@ void MakeSqueezeDimTransReverseWithoutAxis(
if (x_shape[i] != 1) {
trans->emplace_back(std::make_shared<InputDim>(j++));
} else {
trans->emplace_back(new Singleton());
trans->emplace_back(std::make_shared<Singleton>());
}
}
}
Expand All @@ -81,7 +81,7 @@ void MakeSqueezeDimTransReverseWithAxis(
for (int64_t i = 0, j = 0, n = static_cast<int64_t>(x_shape.size()); i < n;
i++) {
if (x_shape[i] == 1) {
trans->emplace_back(new Singleton());
trans->emplace_back(std::make_shared<Singleton>());

auto it = find(axis.begin(), axis.end(), i);
if (it == axis.end()) {
Expand Down Expand Up @@ -144,8 +144,7 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
<< "] Out shape: [" << str_join(out_shape) << "]";
VLOG(4) << "Transformation from input to output:";
for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
std::shared_ptr<DimTrans> t = trans[i];
VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
VLOG(4) << "\tOut axis[" << i << "]: " << trans[i]->to_string();
}
VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0])
Expand Down Expand Up @@ -210,8 +209,7 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
<< "] X shape: [" << str_join(x_shape) << "]";
VLOG(4) << "Transformation from output to input:";
for (int64_t i = 0, n = trans.size(); i < n; i++) {
std::shared_ptr<DimTrans> t = trans[i];
VLOG(4) << "\tX axis[" << i << "]: " << t->to_string();
VLOG(4) << "\tX axis[" << i << "]: " << trans[i]->to_string();
}
VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] "
<< "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
Expand Down