Skip to content

Commit 3b7b923

Browse files
csarofeen authored and soumith committed
Fix grid size for batch cat tensor now that getApplyGrid has been changed.
1 parent 80caca4 commit 3b7b923

File tree

2 files changed

+46
-21
lines changed

2 files changed

+46
-21
lines changed

THCTensorMath.cuh

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t
2626
#define CAT_ARRAY_BATCH_SIZE 1024
2727
#define CAT_ARRAY_MAX_INPUT_DIMS 4
2828

29+
// Computes the launch grid for the batched cat (concatenate) kernel.
//
// grid.x covers the elements of a single input tensor; grid.y indexes the
// input tensor within the batch (one y-slot per tensor being concatenated).
// grid.x is sized to 2 blocks per SM so that concatenating just two tensors
// already fills the whole device.
//
// Returns false when the current CUDA device cannot be determined; on
// success, writes the grid into `grid` and returns true.
inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) {
  int curDevice = -1;
  // Check the API status directly: a failed cudaGetDevice is not guaranteed
  // to leave curDevice at the -1 sentinel.
  if (cudaGetDevice(&curDevice) != cudaSuccess || curDevice == -1) {
    return false;
  }

  // Assume a reasonable number of SMs if no state is available.
  int numSM =
      state ? THCState_getCurrentDeviceProperties(state)->multiProcessorCount : 15;

  // X dim of grid for cat array cooperates on a single tensor in the cat.
  // Given half of the GPU per tensor, full utilization occurs with >= 2 tensors.
  grid = dim3( 2LL * numSM, (long long) nTensors );

  return true;
}
46+
2947
// Similar to any other IndexToOffset calculation for copying along a given dimension.
3048
template <typename IndexType, int Dims>
3149
struct CatArrIndexToOffset {
@@ -77,26 +95,36 @@ struct OutputTensorSizeStride {
7795
*
7896
* The most important assumption made is that the input tensors are contiguous.
7997
*/
98+
99+
100+
80101
// Copies a batch of contiguous input tensors into their slots of the
// concatenated output. Launch layout: blockIdx.y selects the input tensor
// in `inputs`; the x-dimension grid-strides over that tensor's elements.
// Assumes every input tensor is contiguous.
template <typename T, typename IndexType, int Dims>
__global__ void CatArrayBatchedCopy(
    T* output,
    CatArrInputTensor<T, IndexType>* inputs,
    OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {

  const IndexType firstElem = blockIdx.x * blockDim.x + threadIdx.x;
  const IndexType count = inputs[blockIdx.y].nElements;

  // Bail out before touching the remaining metadata when this thread has
  // nothing to copy (e.g. small tensors paired with a large grid).
  if (firstElem >= count) return;

  T* src = inputs[blockIdx.y].input;
  const IndexType dimSize = inputs[blockIdx.y].dimSize;
  const IndexType dstBase = inputs[blockIdx.y].offset * dimStride;
  const IndexType step = gridDim.x * blockDim.x;

  for (IndexType i = firstElem; i < count; i += step) {
    const IndexType dstOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
        os.outputSize, os.outputStride, dimSize, concatDim, i);
    output[dstBase + dstOffset] = src[i];
  }
}
101129

102130
#endif

generic/THCTensorMath.cu

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
207207

208208
// Template Declarations for dim = 1, 2, 3, 4
209209
#define HANDLE_CASE(DIMS) \
210-
CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, stream->stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
210+
CatArrayBatchedCopy<real, unsigned int, DIMS><<<catGrid, applyBlock, 0, stream->stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
211211

212212
// Now we loop
213213
offset = 0;
@@ -243,15 +243,12 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
243243
// is based on.
244244
dim3 applyBlock = getApplyBlock();
245245

246-
// We also re-use the applyGrid - but note that we use the maximum number of
247-
// elements for a given tensor in this grouping to determine the count
248-
dim3 applyGrid;
249-
getApplyGrid(state, cohortMax, applyGrid);
246+
//Get grid where the x dim fills half the GPU and the y dim is the number of tensors.
247+
//This way, concatenating two tensors fills the entire grid, but prevents
248+
//many threads from needlessly loading metadata when the tensors are small.
249+
dim3 catGrid;
250+
getCatGrid(state, j, catGrid);
250251

251-
// Next, we set our grid's y component to be the number of tensors in
252-
// the batch. This will allow the kernel to determine which input
253-
// tensor it is responsible for copying
254-
applyGrid.y = j;
255252

256253
switch (maxDim) {
257254
case 1:

0 commit comments

Comments
 (0)