Skip to content

Commit

Permalink
Make lookup table warp size aware
Browse files — browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#25926

Differential Revision: D17286446

Pulled By: bddppq

fbshipit-source-id: d25515f25f9df309a08ae7f948bb6a087e45134e
  • Loading branch information
iotamudelta authored and facebook-github-bot committed Sep 10, 2019
1 parent 3680cef commit 618804f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions aten/src/THCUNN/generic/LookupTable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#else

#include <thrust/iterator/constant_iterator.h>
+ #include <c10/macros/Macros.h>

void THNN_(LookupTable_accGradParameters)(
THCState *state,
Expand Down Expand Up @@ -36,15 +37,14 @@ void THNN_(LookupTable_accGradParameters)(
cudaStream_t stream = THCState_getCurrentStream(state);

if (numel <= 768 && !scaleGradByFreq) {
- const int WARP_SIZE = 32;
const int BLOCKDIMY = 32;
- dim3 grid(THCCeilDiv(stride, (int64_t)WARP_SIZE));
- dim3 block(WARP_SIZE, BLOCKDIMY);
+ dim3 grid(THCCeilDiv(stride, (int64_t)C10_WARP_SIZE));
+ dim3 block(C10_WARP_SIZE, BLOCKDIMY);

cunn_LookupTable_accGradParametersKernelByFeature<scalar_t, accreal>
<<<grid,
block,
- sizeof(accreal)*WARP_SIZE*BLOCKDIMY + sizeof(int)*WARP_SIZE*BLOCKDIMY,
+ sizeof(accreal)*C10_WARP_SIZE*BLOCKDIMY + sizeof(int)*C10_WARP_SIZE*BLOCKDIMY,
stream>>>
(THCIndexTensor_(data)(state, input),
THCTensor_(data)(state, gradOutput),
Expand Down

0 comments on commit 618804f

Please sign in to comment.