[pnorm] optimize p_norm for special cases #37685
Conversation
Thanks for your contribution!
…op for flexible call.
@Avin0323 Hi, could you please take a look at the benchmark CI? This PR changes the p_norm code and optimizes performance for special shapes; the CI results show that all p_norm test times have decreased. It also modifies the cmake file, so please review that as well.
HOSTDEVICE explicit inline NonzeroFunctor(int n) {}
template <typename T>
HOSTDEVICE inline T operator()(const T& x) const {
  return static_cast<T>(static_cast<double>(x) != 0);
Why cast x to double first with static_cast<double>(x)?
This keeps the original implementation.
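For context, a minimal host-only sketch of the functor above (the HOSTDEVICE macro and the surrounding reduction are assumed): the functor maps each element to 1 or 0, so a subsequent sum yields the count of nonzero entries, i.e. the p = 0 case of p_norm. Going through double first keeps the != 0 comparison well defined for low-precision element types without extra overloads.

#include <cstdio>

// Sketch of NonzeroFunctor: maps x to 1 if x != 0, else 0.
struct NonzeroFunctor {
  explicit inline NonzeroFunctor(int n) {}  // n unused; mirrors the original signature
  template <typename T>
  inline T operator()(const T& x) const {
    // Cast to double first so the comparison is well defined even for
    // low-precision types (e.g. float16).
    return static_cast<T>(static_cast<double>(x) != 0);
  }
};

int main() {
  NonzeroFunctor f(0);
  const float vals[] = {0.0f, 1.5f, -2.0f, 0.0f};
  float count = 0;
  for (float v : vals) count += f(v);  // summing gives the 0-"norm"
  std::printf("nonzero count = %g\n", count);  // prints 2
  return 0;
}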
paddle/fluid/operators/p_norm_op.cu
Outdated
auto xdim = in_x->dims();
auto ndim = out_norm->dims();
float porder = ctx.Attr<float>("porder");
int axis = ctx.Attr<int>("axis");
bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post, asvector);
std::vector<int> reduce_axis = {axis};

auto& dev_ctx = ctx.cuda_device_context();
dev_ctx has no usage in this function?
Yes, I will submit a follow-up PR to remove it.
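Stepping back to the GetDims call above: here is a sketch of the pre/n/post decomposition it appears to compute (an assumption inferred from how the values are used, not the actual Paddle implementation). The tensor is viewed as [pre, n, post], where n is the length of the reduced axis, and asvector collapses the whole tensor into one vector.

#include <cstdio>
#include <vector>

// Hypothetical stand-in for GetDims: factor the shape around `axis` into
// pre (product of leading dims), n (the reduced dim), post (trailing dims).
void GetDimsSketch(const std::vector<int>& xdim, int axis, int* pre, int* n,
                   int* post, bool asvector) {
  *pre = 1;
  *n = 1;
  *post = 1;
  if (asvector) {
    for (int d : xdim) *n *= d;  // treat the whole tensor as one vector
    return;
  }
  for (int i = 0; i < axis; ++i) *pre *= xdim[i];
  *n = xdim[axis];
  for (int i = axis + 1; i < static_cast<int>(xdim.size()); ++i) *post *= xdim[i];
}

int main() {
  int pre, n, post;
  GetDimsSketch({2, 1000, 1000}, /*axis=*/0, &pre, &n, &post, false);
  // Special case (1) from this PR: pre=1, n=2, post=1000000.
  std::printf("pre=%d n=%d post=%d\n", pre, n, post);
  return 0;
}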
} else {
  Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
                                                        porder, norm);
framework::Tensor tmp_x;
Noting a TODO: the tmp_x here should be removed as soon as possible, since it significantly increases GPU memory usage at runtime.
OK, a TODO card has been filed.
auto negs = dx->constant(static_cast<T>(-1.));
auto zeros = dx->constant(static_cast<T>(0.));
auto positives = (*x) > zeros;
dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) *
Does the backward pass here go entirely through Eigen?
Yes, the computation uses Eigen tensors.
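To illustrate the idiom in the snippet above, here is a minimal standalone Eigen Tensor example of building sign(x) from comparisons and select (a sketch of the pattern, not the kernel's exact expression; it assumes Eigen's unsupported Tensor module is available):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 1> x(4);
  x.setValues({-3.f, 0.f, 2.f, -1.f});

  Eigen::Tensor<float, 1> ones(4), zeros(4), negs(4);
  ones.setConstant(1.f);
  zeros.setConstant(0.f);
  negs.setConstant(-1.f);

  // sign(x): +1 where x > 0, -1 where x < 0, 0 elsewhere. A boolean
  // comparison expression exposes select(then, else), as in the kernel.
  Eigen::Tensor<float, 1> sign =
      (x > zeros).select(ones, (x < zeros).select(negs, zeros));
  std::cout << sign << std::endl;  // -1 0 1 -1
  return 0;
}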
@@ -260,32 +254,38 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
float porder = ctx.Attr<float>("porder");
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
int axis = ctx.Attr<int>("axis");
bool reduce_all = ((axis < 0) || (in_norm->numel() == 1));
Does axis < 0 correspond to reduce_all?
Yes.
paddle/fluid/operators/p_norm_op.cu
Outdated
bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post, asvector);
const std::vector<int> dims = {axis};

auto& dev_ctx = ctx.cuda_device_context();
Is dev_ctx still used anywhere?
Removed.
LGTM for PR-CI-OP-benchmark and the changes to unity_build_rule.cmake.
LGTM
PR types
Performance optimization
PR changes
OPs
Describe
Optimize p_norm for two kinds of special cases:
(1) shape=[2, 1000, 1000], reduce axis=0
(2) shape=[1, 2000000, 1], reduce axis=1
The baseline version is paddlepaddle-gpu == 2.2.1. Time denotes seconds per 1k steps.
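As a rough illustration of why case (1) benefits from a specialized path (a hedged sketch, not the PR's actual kernel): with shape [2, 1000, 1000] reduced over axis 0, the decomposition is pre = 1, n = 2, post = 1000000, so each output element reduces only two inputs. One thread per output element with a short serial loop over n keeps loads coalesced and avoids a full block-wide reduction:

#include <cstdio>

// Hypothetical kernel: one thread per output element; each thread serially
// accumulates |x|^p over the short reduced axis n, then takes the 1/p root.
__global__ void PnormSmallAxisSketch(const float* x, float* out, int pre,
                                     int n, int post, float porder) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= pre * post) return;
  int i = idx / post, k = idx % post;
  float sum = 0.f;
  for (int j = 0; j < n; ++j)  // n is small (2 in special case (1))
    sum += powf(fabsf(x[(i * n + j) * post + k]), porder);
  out[idx] = powf(sum, 1.f / porder);
}

int main() {
  const int pre = 1, n = 2, post = 8;  // toy sizes; the PR case has post = 1e6
  float hx[pre * n * post], hout[pre * post];
  for (int i = 0; i < pre * n * post; ++i) hx[i] = 1.f;
  float *dx, *dout;
  cudaMalloc(&dx, sizeof(hx));
  cudaMalloc(&dout, sizeof(hout));
  cudaMemcpy(dx, hx, sizeof(hx), cudaMemcpyHostToDevice);
  PnormSmallAxisSketch<<<1, 256>>>(dx, dout, pre, n, post, 2.f);
  cudaMemcpy(hout, dout, sizeof(hout), cudaMemcpyDeviceToHost);
  std::printf("out[0] = %f\n", hout[0]);  // sqrt(1 + 1) ~= 1.414214
  cudaFree(dx);
  cudaFree(dout);
  return 0;
}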