@@ -66,68 +66,66 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
 }

 template <typename T>
-__global__ void KernelConcat(T** inputs, const int input_col,
-                             const int output_rows, const int output_cols,
-                             T* output) {
+__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
+                             const int out_rows, const int out_cols,
+                             T* output_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  double inv_input_col = 1.0 / input_col;
-  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x * inv_input_col;
-    int in_offset = tid_x - split * input_col;
-    T* input_ptr = inputs[split];
+  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * 1.0 / fixed_in_col;
+    int in_offset = tid_x - split * fixed_in_col;
+    T* input_ptr = inputs_data[split];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
-      output[tid_y * output_cols + tid_x] =
-          input_ptr[tid_y * input_col + in_offset];
+    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
+      output_data[tid_y * out_cols + tid_x] =
+          input_ptr[tid_y * fixed_in_col + in_offset];
     }
   }
 }

 template <typename T>
-__global__ void KernelConcatGrad(const T* input, const int input_row,
-                                 const int input_col, const int* output_cols,
-                                 int col_size, T** outputs) {
+__global__ void KernelConcatGrad(const T* input_data, const int in_row,
+                                 const int in_col, const int* out_cols,
+                                 int out_cols_size, T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
-  int curr_offset = output_cols[segment];
+  int segment = upper_bound<int>(out_cols, out_cols_size, tid_x) - 1;
+  int curr_offset = out_cols[segment];
   int curr_segment = segment;
-  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
     T curr_col_offset;
-    while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
+    while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) {
       curr_offset = curr_col_offset;
       ++curr_segment;
     }

     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
-    T* output_ptr = outputs[curr_segment];
+    T* output_ptr = outputs_data[curr_segment];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
       output_ptr[tid_y * segment_width + local_col] =
-          input[tid_y * input_col + tid_x];
+          input_data[tid_y * in_col + tid_x];
   }
 }

 template <typename T>
-__global__ void KernelConcatGrad(const T* input, const int input_row,
-                                 const int input_col, const int output_cols,
-                                 T** outputs) {
+__global__ void KernelConcatGrad(const T* input_data, const int in_row,
+                                 const int in_col, const int fixed_out_col,
+                                 T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  double inv_input_col = 1.0 / input_col;
-  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x * inv_input_col;
-    int in_offset = tid_x - split * input_col;
-    T* output_ptr = outputs[split];
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x / fixed_out_col;
+    int in_offset = tid_x - split * fixed_out_col;
+    T* output_ptr = outputs_data[split];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
-      output_ptr[tid_y * output_cols + in_offset] =
-          input[tid_y * input_col + tid_x];
+    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * fixed_out_col + in_offset] =
+          input_data[tid_y * in_col + tid_x];
   }
 }

 /*
  * All tensors' dimension should be the same and the values of
- * each dimension are the same, except the axis dimension.
+ * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
 class ConcatFunctor<platform::CUDADeviceContext, T> {
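The fixed-width KernelConcat above treats the output as an out_rows x out_cols matrix and maps every output column back to a (source tensor, column within that tensor) pair by dividing by fixed_in_col. A minimal host-side C++ sketch of the same index arithmetic, for illustration only (it is not part of this commit; the names simply mirror the kernel parameters):

template <typename T>
void ConcatSameWidthCPU(const T* const* inputs_data, int fixed_in_col,
                        int out_rows, int out_cols, T* output_data) {
  for (int tid_x = 0; tid_x < out_cols; ++tid_x) {
    int split = tid_x / fixed_in_col;              // which input tensor
    int in_offset = tid_x - split * fixed_in_col;  // column inside that input
    const T* input_ptr = inputs_data[split];
    for (int tid_y = 0; tid_y < out_rows; ++tid_y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

The CUDA kernel simply distributes these two loops over the x and y dimensions of the grid.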
@@ -136,41 +134,40 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
                   const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
-    int num = input.size();
-    int rows = 1;
+    int in_num = input.size();
+    int in_row = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
-      rows *= dim_0[i];
+      in_row *= dim_0[i];
     }
-    int cols = input[0].numel() / rows;
-    int out_rows = rows, out_cols = 0;
+    int in_col = input[0].numel() / in_row;
+    int out_row = in_row, out_col = 0;

-    framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
-    framework::Vector<int> inputs_cols(num + 1);
-    inputs_cols[0] = 0;
+    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
+    framework::Vector<int> inputs_col(in_num + 1);
     T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());

+    inputs_col[0] = 0;
     bool sameShape = true;
-    for (int i = 0; i < num; ++i) {
-      int t_cols = input[i].numel() / rows;
+    for (int i = 0; i < in_num; ++i) {
+      int t_cols = input[i].numel() / in_row;
       if (sameShape) {
-        if (t_cols != cols) sameShape = false;
+        if (t_cols != in_col) sameShape = false;
       }
-      out_cols += t_cols;
-      inputs_cols[i + 1] = out_cols;
+      out_col += t_cols;
+      inputs_col[i + 1] = out_col;
       inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
     }

-    T** ins_gpu =
+    T** dev_ins_data =
         reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
-    const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());

     // computation
     // set the thread block and grid according to CurrentDeviceId
     const int kThreadsPerBlock = 1024;
     int block_cols = kThreadsPerBlock;
-    if (out_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
-      block_cols = ((out_cols + 31) >> 5) << 5;
+    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((out_col + 31) >> 5) << 5;
     }
     int block_rows = kThreadsPerBlock / block_cols;
     dim3 block_size = dim3(block_cols, block_rows, 1);
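For clarity, both functors pick the thread-block shape with the same heuristic: block_cols is the column count rounded up to the next multiple of 32 (the warp size) and capped at 1024 threads, and block_rows takes whatever threads remain. A small restatement of that logic outside the diff, sketch only, with illustrative worked values in the comments:

#include <utility>

std::pair<int, int> PickBlockShape(int out_col) {
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (out_col < kThreadsPerBlock) {
    block_cols = ((out_col + 31) >> 5) << 5;  // round up to a multiple of 32
  }
  int block_rows = kThreadsPerBlock / block_cols;
  return {block_cols, block_rows};  // e.g. out_col = 50 -> {64, 16}, out_col = 200 -> {224, 4}
}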
@@ -179,25 +176,26 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

     int grid_cols =
-        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
+        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
     int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
+        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);

     if (sameShape) {
       KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-          ins_gpu, cols, out_rows, out_cols, output->data<T>());
+          dev_ins_data, in_col, out_row, out_col, output->data<T>());
     } else {
+      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
       KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-          ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
-          out_cols, output->data<T>());
+          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
+          out_row, out_col, output->data<T>());
     }
   }
 };

 /*
  * All tensors' dimension should be the same and the values of
- * each dimension are the same, except the axis dimension.
+ * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
 class ConcatGradFunctor<platform::CUDADeviceContext, T> {
@@ -206,41 +204,40 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
                   const framework::Tensor& input, const int axis,
                   std::vector<framework::Tensor>& outputs) {
     // TODO(zcd): Add input data validity checking
-    int num = outputs.size();
-    int input_row = 1;
+    int o_num = outputs.size();
+    int out_row = 1;
     auto dim_0 = outputs[0].dims();
     for (int i = 0; i < axis; ++i) {
-      input_row *= dim_0[i];
+      out_row *= dim_0[i];
     }

-    int output_col_0 = outputs[0].numel() / input_row;
-    int input_col = 0;
+    int out_col = outputs[0].numel() / out_row;
+    int in_col = 0, in_row = out_row;
     bool sameShape = true;

-    framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
-    framework::Vector<int> outputs_cols(num + 1);
-    outputs_cols[0] = 0;
+    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
+    framework::Vector<int> outputs_cols(o_num + 1);
     T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());

-    for (int i = 0; i < num; ++i) {
-      int t_col = outputs[i].numel() / input_row;
+    outputs_cols[0] = 0;
+    for (int i = 0; i < o_num; ++i) {
+      int t_col = outputs[i].numel() / out_row;
       if (sameShape) {
-        if (t_col != output_col_0) sameShape = false;
+        if (t_col != out_col) sameShape = false;
       }
-      input_col += t_col;
-      outputs_cols[i + 1] = input_col;
+      in_col += t_col;
+      outputs_cols[i + 1] = in_col;
       outputs_ptr[i] = outputs[i].data<T>();
     }

-    T** outs_gpu =
+    T** dev_out_gpu_data =
         reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
-    const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());

     // computation
     const int kThreadsPerBlock = 1024;
     int block_cols = kThreadsPerBlock;
-    if (input_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
-      block_cols = ((input_col + 31) >> 5) << 5;
+    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((in_col + 31) >> 5) << 5;
     }
     int block_rows = kThreadsPerBlock / block_cols;
     dim3 block_size = dim3(block_cols, block_rows, 1);
@@ -249,18 +246,19 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

     int grid_cols =
-        std::min((input_col + block_cols - 1) / block_cols, max_blocks);
+        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
     int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
+        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);

     if (sameShape) {
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
+          input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
     } else {
+      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), input_row, input_col, outs_col_gpu,
-          static_cast<int>(outputs_cols.size()), outs_gpu);
+          input.data<T>(), in_row, in_col, dev_outs_col_data,
+          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
     }
   }
 };
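The gradient path relies on outputs_cols holding prefix sums of the per-output widths (outputs_cols[0] = 0, outputs_cols[i + 1] = outputs_cols[i] + width_i), so that outputs_cols[s] <= tid_x < outputs_cols[s + 1] identifies the output tensor that input column tid_x belongs to. A host-side C++ sketch of that scatter, illustrative only and not part of this commit (it mirrors the variable-width KernelConcatGrad):

template <typename T>
void ConcatGradScatterCPU(const T* input_data, int in_row, int in_col,
                          const int* out_cols,  // prefix sums; last entry == in_col
                          T* const* outputs_data) {
  for (int tid_x = 0; tid_x < in_col; ++tid_x) {
    int curr_segment = 0;
    while (out_cols[curr_segment + 1] <= tid_x) {  // find the containing segment
      ++curr_segment;
    }
    int curr_offset = out_cols[curr_segment];
    int segment_width = out_cols[curr_segment + 1] - curr_offset;
    int local_col = tid_x - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    for (int tid_y = 0; tid_y < in_row; ++tid_y) {
      output_ptr[tid_y * segment_width + local_col] =
          input_data[tid_y * in_col + tid_x];
    }
  }
}

For example, with output widths {3, 4, 5} the prefix sums are {0, 3, 7, 12}, so input column 8 lands in the third output (index 2) at local column 1.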