【Hackathon 5th No.52】 Add spmd sharding inference rules for squeeze and unsqueeze to Paddle #57877
@@ -29,74 +29,59 @@ namespace distributed {

 using phi::distributed::auto_parallel::str_join;

-std::vector<DimTrans*> MakeSqueezeDimTransWithoutAxis(
-    const std::vector<int64_t>& x_shape, std::vector<int64_t>* out_shape) {
-  std::vector<DimTrans*> ret;
-
+void MakeSqueezeDimTransWithoutAxis(const std::vector<int64_t>& x_shape,
+                                    std::vector<int64_t>* out_shape,
+                                    std::vector<DimTrans*>* trans) {
   for (int64_t i = 0, n = static_cast<int64_t>(x_shape.size()); i < n; i++) {
     if (x_shape[i] != 1) {
-      ret.emplace_back(new InputDim(i));
+      trans->emplace_back(new InputDim(i));
       out_shape->emplace_back(x_shape[i]);
     }
   }
-
-  return ret;
 }

-std::vector<DimTrans*> MakeSqueezeDimTransWithAxis(
-    const std::vector<int64_t>& x_shape,
-    std::vector<int64_t>* out_shape,
-    const std::vector<int64_t>& axis) {
-  std::vector<DimTrans*> ret;
-
+void MakeSqueezeDimTransWithAxis(const std::vector<int64_t>& x_shape,
+                                 std::vector<int64_t>* out_shape,
+                                 const std::vector<int64_t>& axis,
+                                 std::vector<DimTrans*>* trans) {
   for (int64_t i = 0, n = static_cast<int64_t>(x_shape.size()); i < n; i++) {
-    ret.emplace_back(new InputDim(i));
+    trans->emplace_back(new InputDim(i));
     out_shape->emplace_back(x_shape[i]);
   }

   for (int64_t i = 0, n = static_cast<int64_t>(axis.size()); i < n; i++) {
     if (x_shape[axis[i]] == 1) {
-      ret.erase(ret.begin() + axis[i]);
+      trans->erase(trans->begin() + axis[i]);
       out_shape->erase(out_shape->begin() + axis[i]);
     }
   }
-
-  return ret;
 }

-std::vector<DimTrans*> MakeSqueezeDimTransReverseWithoutAxis(
-    const std::vector<int64_t>& x_shape) {
-  std::vector<DimTrans*> ret;
-
+void MakeSqueezeDimTransReverseWithoutAxis(const std::vector<int64_t>& x_shape,
+                                           std::vector<DimTrans*>* trans) {
   for (int64_t i = 0, j = 0, n = static_cast<int64_t>(x_shape.size()); i < n;
        i++) {
     if (x_shape[i] != 1) {
-      ret.emplace_back(new InputDim(j++));
+      trans->emplace_back(new InputDim(j++));
     } else {
-      ret.emplace_back(new Singleton());
+      trans->emplace_back(new Singleton());

Review comment on this line: Change this to make_shared.
Reply: Done
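A minimal sketch of what that suggestion might look like, assuming the DimTrans hierarchy is held through std::shared_ptr rather than raw owning pointers; the shared_ptr-based signature below is illustrative only and is not the signature merged in this PR:

```cpp
// Illustrative sketch: DimTrans, InputDim and Singleton are the types from
// this file; only the ownership model (shared_ptr + make_shared) is assumed.
#include <memory>
#include <vector>

void MakeSqueezeDimTransReverseWithoutAxis(
    const std::vector<int64_t>& x_shape,
    std::vector<std::shared_ptr<DimTrans>>* trans) {
  for (int64_t i = 0, j = 0, n = static_cast<int64_t>(x_shape.size()); i < n;
       i++) {
    if (x_shape[i] != 1) {
      // Dimension kept from the input: map it to the next output dim index.
      trans->emplace_back(std::make_shared<InputDim>(j++));
    } else {
      // Size-1 dimension that gets squeezed away.
      trans->emplace_back(std::make_shared<Singleton>());
    }
  }
}
```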
     }
   }
-
-  return ret;
 }

-std::vector<DimTrans*> MakeSqueezeDimTransReverseWithAxis(
-    const std::vector<int64_t>& x_shape,
-    const std::vector<int64_t>& out_shape,
-    const std::vector<int64_t>& axis) {
-  std::vector<DimTrans*> ret;
-
+void MakeSqueezeDimTransReverseWithAxis(const std::vector<int64_t>& x_shape,
+                                        const std::vector<int64_t>& out_shape,
+                                        const std::vector<int64_t>& axis,
+                                        std::vector<DimTrans*>* trans) {
   for (int64_t i = 0, n = static_cast<int64_t>(out_shape.size()); i < n; i++) {
-    ret.emplace_back(new InputDim(i));
+    trans->emplace_back(new InputDim(i));
   }

   for (int64_t i = 0, n = static_cast<int64_t>(axis.size()); i < n; i++) {
     if (x_shape[axis[i]] == 1) {
-      ret.emplace(ret.begin() + axis[i], new Singleton());
+      trans->emplace(trans->begin() + axis[i], new Singleton());
     }
   }
-
-  return ret;
 }

 bool SqueezeCompare(const int64_t& a, const int64_t& b) { return a > b; }
@@ -125,7 +110,7 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
   std::vector<int64_t> out_shape;

   if (static_cast<int64_t>(axis.size()) == 0) {
-    trans = MakeSqueezeDimTransWithoutAxis(x_shape, &out_shape);
+    MakeSqueezeDimTransWithoutAxis(x_shape, &out_shape, &trans);
   } else {
     std::vector<int64_t> axis_copy(axis);
     for (int64_t i = 0, n = static_cast<int64_t>(axis_copy.size()); i < n;

@@ -135,7 +120,7 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
       }
     }
     std::sort(axis_copy.begin(), axis_copy.end(), SqueezeCompare);

Review comment on this line: Drop the custom comparison function.
Reply: Same as above (Done)
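A sketch of what the reviewer likely has in mind: sorting the axes in descending order with the standard library's std::greater instead of the hand-written SqueezeCompare. The snippet below is a standalone illustration, not code from this PR:

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

int main() {
  // Axes to squeeze; sorting them in descending order lets elements be
  // erased back-to-front without shifting the indices that remain.
  std::vector<int64_t> axis_copy = {0, 3, 1};
  std::sort(axis_copy.begin(), axis_copy.end(), std::greater<int64_t>());
  // axis_copy is now {3, 1, 0}, the same ordering SqueezeCompare produces.
  return 0;
}
```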
-    trans = MakeSqueezeDimTransWithAxis(x_shape, &out_shape, axis_copy);
+    MakeSqueezeDimTransWithAxis(x_shape, &out_shape, axis_copy, &trans);
   }

   // Step2: Infer the dims mapping of input (if reshard is
@@ -194,7 +179,7 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
   std::vector<DimTrans*> trans;

   if (static_cast<int64_t>(axis.size()) == 0) {
-    trans = MakeSqueezeDimTransReverseWithoutAxis(x_shape);
+    MakeSqueezeDimTransReverseWithoutAxis(x_shape, &trans);
   } else {
     std::vector<int64_t> axis_copy(axis);
     for (int64_t i = 0, n = static_cast<int64_t>(axis_copy.size()); i < n;

@@ -204,7 +189,7 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
      }
     }
     std::sort(axis_copy.begin(), axis_copy.end(), SqueezeReverseCompare);

Review comment on this line: Use the standard library's built-in comparison function.
Reply: Same as above (Done)
-    trans = MakeSqueezeDimTransReverseWithAxis(x_shape, out_shape, axis_copy);
+    MakeSqueezeDimTransReverseWithAxis(x_shape, out_shape, axis_copy, &trans);
   }

   // Step2: Infer the dims mapping of input with
Review comment: This could be written more simply: first compute the output size, then iterate over the output dimensions; if a dimension is in axis and its shape is 1, use Singleton, otherwise use InputDim.
Reply: Done
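A rough sketch of the single-pass structure the reviewer describes. This is illustrative only: the function name MakeDimTransSimple is hypothetical, the diff above does not show which squeeze/unsqueeze helper the comment targets, and `axis` is assumed to already be normalized to non-negative indices over the target shape:

```cpp
// Illustrative sketch, not the code merged in this PR. DimTrans, InputDim and
// Singleton are the types used in this file.
#include <algorithm>
#include <cstdint>
#include <vector>

void MakeDimTransSimple(const std::vector<int64_t>& target_shape,
                        const std::vector<int64_t>& axis,
                        std::vector<DimTrans*>* trans) {
  int64_t input_dim = 0;  // index of the next dimension carried over from the input
  for (int64_t i = 0, n = static_cast<int64_t>(target_shape.size()); i < n;
       i++) {
    bool in_axis = std::find(axis.begin(), axis.end(), i) != axis.end();
    if (in_axis && target_shape[i] == 1) {
      trans->emplace_back(new Singleton());            // size-1 dim named in axis
    } else {
      trans->emplace_back(new InputDim(input_dim++));  // everything else maps to an input dim
    }
  }
}
```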