Speed up elemwise grad #8402
I think elementwise_add_grad could be implemented as a matrix multiplication, which might be faster.
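One way to read this suggestion, as a sketch only: when y is broadcast along the rows of an h x w input, dy[j] is the sum of column j of dout, which equals the product of a length-h row vector of ones with dout. The function name and the row-major layout below are assumptions for illustration; a real implementation would presumably call a BLAS matrix-vector routine rather than this naive loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: dy = ones(1 x h) * dout(h x w), dout in row-major order.
// For elementwise add, d(out)/d(y) is 1, so dy[j] is the sum of column j of dout.
std::vector<float> add_grad_dy_as_matvec(const std::vector<float>& dout,
                                         std::size_t h, std::size_t w) {
  assert(dout.size() == h * w);
  const std::vector<float> ones(h, 1.0f);
  std::vector<float> dy(w, 0.0f);
  for (std::size_t i = 0; i < h; ++i) {
    for (std::size_t j = 0; j < w; ++j) {
      dy[j] += ones[i] * dout[i * w + j];
    }
  }
  return dy;
}
```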
}
}
struct IdentityGrad {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
Add inline
Actually, whether a call is inlined is decided by the compiler.
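For context, a member function defined inside its class body, as operator() is in the diff, is already implicitly inline in C++, so the keyword would add nothing there; whether the call is actually expanded in place remains the compiler's decision. A host-only sketch follows. The template header is an assumption (it is not visible in the diff), and HOSTDEVICE is left out on the assumption that it only adds CUDA qualifiers.

```cpp
#include <cassert>

// Host-only version of the functor from the diff (HOSTDEVICE omitted).
template <typename T>
struct IdentityGrad {
  // Defined in the class body, hence implicitly inline.
  T operator()(T x, T y, T out, T dout) const { return dout; }
};
```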
do {
  int x_offset = i * w + j;
  if (dx) {
    dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
I wonder whether this will be faster than before. For elementwise_add_grad, computing dx uses only dout, but line 374 also causes the unused data (x, y, out) to be loaded from global memory into registers.
nvcc can optimize away the memory accesses when the functor does not use the variable. I checked this by reading the generated PTX file.
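The reasoning can be sketched on the host: once the functor call is inlined, the values loaded for x, y and out feed no computation, so an optimizing compiler is free to drop those loads. The function below is a hypothetical CPU analog of the dx part of the kernel, not the PR's code, and whether a given compiler actually removes the loads still has to be confirmed in its output, as was done here with the PTX.

```cpp
#include <cassert>
#include <cstddef>

// Gradient functor for add: ignores x, y and out.
template <typename T>
struct IdentityGrad {
  T operator()(T x, T y, T out, T dout) const { return dout; }
};

// Hypothetical CPU analog of the dx computation: y (length w) is broadcast
// over the rows of an h x w row-major input.
template <typename T, typename DxOp>
void elemwise_grad_dx(const T* x, const T* y, const T* out, const T* dout,
                      std::size_t h, std::size_t w, DxOp dx_op, T* dx) {
  assert(dx != nullptr);
  for (std::size_t i = 0; i < h; ++i) {
    for (std::size_t j = 0; j < w; ++j) {
      const std::size_t x_offset = i * w + j;
      // With IdentityGrad, only dout[x_offset] contributes to the result.
      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }
  }
}
```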
> I just checked this by reading the generated PTX file.
Cool!
    shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
  }
  i += 1024;
} while (i < h);
Lines 378-380 are confusing. It seems 1024 should be blockDim.x.
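To make the point concrete, here is a host-side emulation of the loop, a sketch rather than the PR's kernel. Each of block_dim simulated threads starts at row tid (an assumption, since the initialization of i is not shown in the diff) and advances by block_dim, so every row is visited exactly once for any block size. A hard-coded stride of 1024 gives the same result only when the block really has 1024 threads. The shared-memory reduction is replaced by a plain sum.

```cpp
#include <cassert>
#include <vector>

// Emulates one block reducing column j of an h x w row-major dout into dy[j].
float reduce_column(const std::vector<float>& dout, int h, int w, int j,
                    int block_dim) {
  assert(block_dim > 0 && j >= 0 && j < w);
  assert(static_cast<int>(dout.size()) == h * w);
  std::vector<float> shm(block_dim, 0.0f);  // stands in for shared memory
  for (int tid = 0; tid < block_dim; ++tid) {
    // Stride by the block size, not a literal 1024. A for loop is used so
    // that simulated threads with tid >= h do no work.
    for (int i = tid; i < h; i += block_dim) {
      shm[tid] += dout[i * w + j];
    }
  }
  float sum = 0.0f;
  for (int tid = 0; tid < block_dim; ++tid) sum += shm[tid];
  return sum;
}
```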
… feature/fix_elemwise_grad
shm[tid] = 0;

do {
  int x_offset = i * w + j;
The accesses to x, dx, and dout are not contiguous, which may hurt performance.
Indeed. However, this layout makes the reduction easier. There may be a more efficient implementation.
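The trade-off can be illustrated with plain index arithmetic. This sketch assumes one block per column j and that thread tid starts at row tid (neither is visible in the diff). Under that mapping neighboring threads touch addresses w elements apart, whereas a hypothetical mapping of threads to columns would make them adjacent. Both function names are illustrative.

```cpp
#include <cassert>

// Offset touched first by thread tid when a block walks down column j
// (the mapping assumed for the diff).
inline int offset_column_per_block(int tid, int j, int w) {
  assert(j >= 0 && j < w);
  return tid * w + j;
}

// Offset touched by thread tid when threads are spread across row i instead
// (hypothetical alternative mapping).
inline int offset_row_per_block(int i, int tid, int w) {
  assert(tid >= 0 && tid < w);
  return i * w + tid;
}
```

With the second mapping each thread would accumulate its own column, so no cross-thread reduction would be needed, but only w threads would be active per row.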

while (true) {
  int i = ttid / post;
  int k = ttid % post;
Integer division is expensive; I recommend multiplying by the reciprocal instead:
float inv_post = 1.0f / post;
while (true) {
  int i = ttid * inv_post;
  int k = ttid - i * post;
  ...
I am not sure which is faster: multiplication of float values or division of integers. In any case, these lines should not cost much time, since they are not the main logic of the method.
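Apart from speed, the two variants are not interchangeable for large indices: float has a 24-bit significand, so an index above 2^24 may be rounded when it is converted. The sketch below, with hypothetical function names, puts the integer decomposition from the diff next to the reciprocal-based one so they can be compared.

```cpp
#include <cassert>

struct IndexPair {
  int i;
  int k;
};

// Integer decomposition, as in the diff: ttid = i * post + k.
inline IndexPair split_int(int ttid, int post) {
  assert(post > 0);
  return IndexPair{ttid / post, ttid % post};
}

// Reciprocal-based variant, following the review suggestion.
inline IndexPair split_float(int ttid, int post) {
  assert(post > 0);
  const float inv_post = 1.0f / static_cast<float>(post);
  const float t = static_cast<float>(ttid);  // may round if ttid > 2^24
  const float q = t * inv_post;
  const int i = static_cast<int>(q);
  return IndexPair{i, ttid - i * post};
}
```

For example, 16777217 (2^24 + 1) is not representable as a float, so split_float(16777217, 1) cannot return the same i as split_int(16777217, 1).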
  int k = ttid % post;
  if (i >= pre) break;

  int x_offset = i * n * post + j * post + k;
int x_offset = i * n * post + j * post + k;
==>
int x_offset = (i * n + j) * post + k;
The compiler should optimize this expression.
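Both forms compute the same offset; the factored one needs one multiplication fewer if the compiler does not already do the rewrite. As a sketch of how the offset is used, the hypothetical host function below reduces a row-major dout of shape (pre, n, post) over the pre and post axes to obtain dy[j] for elementwise add. It is an illustration, not the PR's kernel.

```cpp
#include <cassert>
#include <vector>

// Flat offset of element (i, j, k) in a row-major (pre, n, post) array.
inline int offset_expanded(int i, int j, int k, int n, int post) {
  return i * n * post + j * post + k;
}
inline int offset_factored(int i, int j, int k, int n, int post) {
  return (i * n + j) * post + k;
}

// Hypothetical host analog: dy[j] = sum over i and k of dout[i][j][k].
std::vector<float> add_grad_dy(const std::vector<float>& dout, int pre, int n,
                               int post) {
  assert(pre > 0 && n > 0 && post > 0);
  assert(static_cast<int>(dout.size()) == pre * n * post);
  std::vector<float> dy(n, 0.0f);
  for (int j = 0; j < n; ++j) {
    for (int ttid = 0; ttid < pre * post; ++ttid) {
      const int i = ttid / post;
      const int k = ttid % post;
      dy[j] += dout[offset_factored(i, j, k, n, post)];
    }
  }
  return dy;
}
```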
7e3ae7e to cad4d76
Great work. In my benchmark, the share of time spent in elementwise grad drops from 85.8% to 2.3%.
The alternative solution suggested by @chengduoZH is very valuable. But I would suggest merging this PR first. We can make another PR if later on this 2.3% becomes the bottleneck. 👍
Fix #7862