@@ -26,6 +26,24 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t
 #define CAT_ARRAY_BATCH_SIZE 1024
 #define CAT_ARRAY_MAX_INPUT_DIMS 4
 
+inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) {
+  int curDevice = -1;
+  cudaGetDevice(&curDevice);
+
+  if (curDevice == -1) {
+    return false;
+  }
+
+  // Assume a reasonable number of SMs if no state is available
+  int numSM =
+    state ? THCState_getCurrentDeviceProperties(state)->multiProcessorCount : 15;
+  // X dim of grid for cat array cooperates on a single tensor in the cat.
+  // Given half of the GPU, full utilization will always occur.
+  grid = dim3(2LL * numSM, (long long) nTensors);
+
+  return true;
+}
+
 // Similar to any other IndexToOffset calculation for copying along a given dimension.
 template <typename IndexType, int Dims>
 struct CatArrIndexToOffset {
@@ -77,26 +95,36 @@ struct OutputTensorSizeStride {
  *
  * The most important assumption made is that the input tensors are contiguous.
  */
+
+
+
 template <typename T, typename IndexType, int Dims>
 __global__ void CatArrayBatchedCopy(
     T* output,
     CatArrInputTensor<T, IndexType>* inputs,
     OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
     const int concatDim,
     IndexType dimStride) {
-  T* data = inputs[blockIdx.y].input;
-  IndexType offset = inputs[blockIdx.y].offset;
-  IndexType dimSize = inputs[blockIdx.y].dimSize;
-  IndexType nElements = inputs[blockIdx.y].nElements;
-  IndexType dataOffset = offset * dimStride;
-
-  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
-       linearIndex < nElements;
-       linearIndex += gridDim.x * blockDim.x) {
+
+  IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
+  IndexType nElements = inputs[blockIdx.y].nElements;
+
+  if (tid >= nElements) return;
+
+  T* data = inputs[blockIdx.y].input;
+  IndexType offset = inputs[blockIdx.y].offset;
+  IndexType dimSize = inputs[blockIdx.y].dimSize;
+  IndexType dataOffset = offset * dimStride;
+
+  IndexType stride = gridDim.x * blockDim.x;
+
+  while (tid < nElements) {
     IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
-        os.outputSize, os.outputStride, dimSize, concatDim, linearIndex);
-    output[dataOffset + elementOffset] = data[linearIndex];
-  }
+        os.outputSize, os.outputStride, dimSize, concatDim, tid);
+    output[dataOffset + elementOffset] = data[tid];
+
+    tid += stride;
+  }
 }
 
 #endif
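
For reference, the sketch below is not part of this patch; it is a minimal, standalone illustration of the launch pattern the change relies on, in which blockIdx.y selects the input tensor, the x dimension grid-strides over that tensor's elements, and grid.x is sized from the SM count as in getCatGrid. All names here (Slice, copyBatched, the sizes) are hypothetical stand-ins for CatArrInputTensor and CatArrayBatchedCopy, and the real kernel additionally maps each element through CatArrIndexToOffset.

// Minimal sketch (hypothetical names) of the batched, grid-stride copy pattern.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

struct Slice {           // stand-in for CatArrInputTensor
  const float* input;    // contiguous input data
  int offset;            // element offset of this slice within the output
  int nElements;         // number of elements in this slice
};

__global__ void copyBatched(float* output, const Slice* slices) {
  Slice s = slices[blockIdx.y];                        // y picks the input tensor
  for (int i = blockIdx.x * blockDim.x + threadIdx.x;  // x grid-strides over it
       i < s.nElements;
       i += gridDim.x * blockDim.x) {
    output[s.offset + i] = s.input[i];
  }
}

int main() {
  const int nSlices = 2, n = 1 << 16;
  std::vector<float> hA(n, 1.f), hB(n, 2.f), hOut(2 * n);
  float *dA, *dB, *dOut;
  cudaMalloc(&dA, n * sizeof(float));
  cudaMalloc(&dB, n * sizeof(float));
  cudaMalloc(&dOut, 2 * n * sizeof(float));
  cudaMemcpy(dA, hA.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  Slice hSlices[nSlices] = {{dA, 0, n}, {dB, n, n}};
  Slice* dSlices;
  cudaMalloc(&dSlices, sizeof(hSlices));
  cudaMemcpy(dSlices, hSlices, sizeof(hSlices), cudaMemcpyHostToDevice);

  // Grid sized as in getCatGrid: 2 * #SMs blocks per tensor in x, one row per tensor in y.
  int dev = 0, numSM = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&numSM, cudaDevAttrMultiProcessorCount, dev);
  dim3 grid(2 * numSM, nSlices);
  copyBatched<<<grid, 256>>>(dOut, dSlices);
  cudaDeviceSynchronize();

  cudaMemcpy(hOut.data(), dOut, 2 * n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("%f %f\n", hOut[0], hOut[n]);  // expect 1.000000 2.000000
  cudaFree(dA); cudaFree(dB); cudaFree(dOut); cudaFree(dSlices);
  return 0;
}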