
Commit 4a92e89

Author: chengduo

Merge pull request #9337 from chengduoZH/feature/fix_concat

Fix concat_op

2 parents: 12856c5 + aca9180

3 files changed: 101 additions, 87 deletions


paddle/fluid/operators/math/concat.cc (2 additions, 2 deletions)

@@ -20,7 +20,7 @@ namespace math {
 
 /*
  * All tensors' dimension should be the same and the values of
- * each dimension are the same, except the axis dimension.
+ * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
 class ConcatFunctor<platform::CPUDeviceContext, T> {
@@ -63,7 +63,7 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
 
 /*
  * All tensors' dimension should be the same and the values of
- * each dimension are the same, except the axis dimension.
+ * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
 class ConcatGradFunctor<platform::CPUDeviceContext, T> {
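The .cc change touches only the comment, but it pins down the operator's contract: inputs must agree in every dimension except the concatenation axis. A minimal sketch of that precondition in NumPy terms (a hypothetical checker, not part of this commit):

    import numpy as np

    def check_concat_shapes(tensors, axis):
        # Every dimension except `axis` must match the first tensor's.
        ref = list(tensors[0].shape)
        for t in tensors[1:]:
            dims = list(t.shape)
            dims[axis] = ref[axis]  # the concat axis itself may differ
            assert dims == ref, "non-axis dimensions must be equal"

    # Valid: shapes differ only along axis 1 (same shapes as the unit test below).
    check_concat_shapes(
        [np.zeros((2, 1, 4, 5)), np.zeros((2, 2, 4, 5)), np.zeros((2, 3, 4, 5))],
        axis=1)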

paddle/fluid/operators/math/concat.cu (76 additions, 78 deletions)

@@ -66,68 +66,66 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
 }
 
 template <typename T>
-__global__ void KernelConcat(T** inputs, const int input_col,
-                             const int output_rows, const int output_cols,
-                             T* output) {
+__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
+                             const int out_rows, const int out_cols,
+                             T* output_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  double inv_input_col = 1.0 / input_col;
-  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x * inv_input_col;
-    int in_offset = tid_x - split * input_col;
-    T* input_ptr = inputs[split];
+  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x * 1.0 / fixed_in_col;
+    int in_offset = tid_x - split * fixed_in_col;
+    T* input_ptr = inputs_data[split];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
-      output[tid_y * output_cols + tid_x] =
-          input_ptr[tid_y * input_col + in_offset];
+    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
+      output_data[tid_y * out_cols + tid_x] =
+          input_ptr[tid_y * fixed_in_col + in_offset];
     }
   }
 }
 
 template <typename T>
-__global__ void KernelConcatGrad(const T* input, const int input_row,
-                                 const int input_col, const int* output_cols,
-                                 int col_size, T** outputs) {
+__global__ void KernelConcatGrad(const T* input_data, const int in_row,
+                                 const int in_col, const int* out_cols,
+                                 int out_cols_size, T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
-  int curr_offset = output_cols[segment];
+  int segment = upper_bound<int>(out_cols, out_cols_size, tid_x) - 1;
+  int curr_offset = out_cols[segment];
   int curr_segment = segment;
-  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
     T curr_col_offset;
-    while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
+    while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) {
       curr_offset = curr_col_offset;
      ++curr_segment;
     }
 
     int local_col = tid_x - curr_offset;
     int segment_width = curr_col_offset - curr_offset;
-    T* output_ptr = outputs[curr_segment];
+    T* output_ptr = outputs_data[curr_segment];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
       output_ptr[tid_y * segment_width + local_col] =
-          input[tid_y * input_col + tid_x];
+          input_data[tid_y * in_col + tid_x];
  }
 }
 
 template <typename T>
-__global__ void KernelConcatGrad(const T* input, const int input_row,
-                                 const int input_col, const int output_cols,
-                                 T** outputs) {
+__global__ void KernelConcatGrad(const T* input_data, const int in_row,
+                                 const int in_col, const int fixed_out_col,
+                                 T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  double inv_input_col = 1.0 / input_col;
-  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x * inv_input_col;
-    int in_offset = tid_x - split * input_col;
-    T* output_ptr = outputs[split];
+  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x / fixed_out_col;
+    int in_offset = tid_x - split * fixed_out_col;
+    T* output_ptr = outputs_data[split];
     int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
-      output_ptr[tid_y * output_cols + in_offset] =
-          input[tid_y * input_col + tid_x];
+    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * fixed_out_col + in_offset] =
+          input_data[tid_y * in_col + tid_x];
   }
 }
 
 /*
  * All tensors' dimension should be the same and the values of
- * each dimension are the same, except the axis dimension.
+ * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
 class ConcatFunctor<platform::CUDADeviceContext, T> {
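In the renamed same-shape kernel, every input is viewed as an out_rows by fixed_in_col matrix, so an output column index tid_x decomposes into split = tid_x / fixed_in_col (which input) and in_offset = tid_x - split * fixed_in_col (the column inside it); the kernel reaches the same quotient through the double expression tid_x * 1.0 / fixed_in_col. The index arithmetic, replayed sequentially in NumPy as a sketch rather than as the kernel itself:

    import numpy as np

    in_row, fixed_in_col, n = 2, 6, 3
    inputs = [100 * i + np.arange(in_row * fixed_in_col).reshape(in_row, fixed_in_col)
              for i in range(n)]
    out_cols = n * fixed_in_col
    out = np.empty((in_row, out_cols), dtype=int)

    for tid_x in range(out_cols):          # the kernel's x dimension
        split = tid_x // fixed_in_col      # which input tensor
        in_offset = tid_x - split * fixed_in_col
        for tid_y in range(in_row):        # the kernel's y dimension
            out[tid_y, tid_x] = inputs[split][tid_y, in_offset]

    assert np.array_equal(out, np.concatenate(inputs, axis=1))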
@@ -136,41 +134,40 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
                   const std::vector<framework::Tensor>& input, const int axis,
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
-    int num = input.size();
-    int rows = 1;
+    int in_num = input.size();
+    int in_row = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
-      rows *= dim_0[i];
+      in_row *= dim_0[i];
     }
-    int cols = input[0].numel() / rows;
-    int out_rows = rows, out_cols = 0;
+    int in_col = input[0].numel() / in_row;
+    int out_row = in_row, out_col = 0;
 
-    framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
-    framework::Vector<int> inputs_cols(num + 1);
-    inputs_cols[0] = 0;
+    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
+    framework::Vector<int> inputs_col(in_num + 1);
     T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
 
+    inputs_col[0] = 0;
     bool sameShape = true;
-    for (int i = 0; i < num; ++i) {
-      int t_cols = input[i].numel() / rows;
+    for (int i = 0; i < in_num; ++i) {
+      int t_cols = input[i].numel() / in_row;
       if (sameShape) {
-        if (t_cols != cols) sameShape = false;
+        if (t_cols != in_col) sameShape = false;
       }
-      out_cols += t_cols;
-      inputs_cols[i + 1] = out_cols;
+      out_col += t_cols;
+      inputs_col[i + 1] = out_col;
       inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
     }
 
-    T** ins_gpu =
+    T** dev_ins_data =
         reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
-    const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
 
     // computation
     // set the thread block and grid according to CurrentDeviceId
     const int kThreadsPerBlock = 1024;
     int block_cols = kThreadsPerBlock;
-    if (out_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
-      block_cols = ((out_cols + 31) >> 5) << 5;
+    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((out_col + 31) >> 5) << 5;
     }
     int block_rows = kThreadsPerBlock / block_cols;
     dim3 block_size = dim3(block_cols, block_rows, 1);
@@ -179,25 +176,26 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
 
     int grid_cols =
-        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
+        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
     int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
+        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
     if (sameShape) {
       KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-          ins_gpu, cols, out_rows, out_cols, output->data<T>());
+          dev_ins_data, in_col, out_row, out_col, output->data<T>());
     } else {
+      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
       KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
-          ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
-          out_cols, output->data<T>());
+          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
+          out_row, out_col, output->data<T>());
     }
   }
 };
 
 /*
  * All tensors' dimension should be the same and the values of
- * each dimension are the same, except the axis dimension.
+ * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
 class ConcatGradFunctor<platform::CUDADeviceContext, T> {
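Beyond the renames, the substantive change in ConcatFunctor is that inputs_col.CUDAData(...) is now called only in the else branch: when all inputs share a shape, the kernel needs nothing but the fixed column width, so the host-to-device copy of the offset table is skipped on the fast path. The launch configuration itself is unchanged; ((out_col + 31) >> 5) << 5 merely rounds the column count up to the warp size of 32, as this small restatement shows:

    def round_up_to_warp(cols, warp=32):
        # Equivalent to ((cols + 31) >> 5) << 5 when warp == 32.
        return ((cols + warp - 1) // warp) * warp

    assert round_up_to_warp(1) == 32
    assert round_up_to_warp(33) == 64
    assert round_up_to_warp(1024) == 1024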
@@ -206,41 +204,40 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
                   const framework::Tensor& input, const int axis,
                   std::vector<framework::Tensor>& outputs) {
     // TODO(zcd): Add input data validity checking
-    int num = outputs.size();
-    int input_row = 1;
+    int o_num = outputs.size();
+    int out_row = 1;
     auto dim_0 = outputs[0].dims();
     for (int i = 0; i < axis; ++i) {
-      input_row *= dim_0[i];
+      out_row *= dim_0[i];
     }
 
-    int output_col_0 = outputs[0].numel() / input_row;
-    int input_col = 0;
+    int out_col = outputs[0].numel() / out_row;
+    int in_col = 0, in_row = out_row;
     bool sameShape = true;
 
-    framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
-    framework::Vector<int> outputs_cols(num + 1);
-    outputs_cols[0] = 0;
+    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
+    framework::Vector<int> outputs_cols(o_num + 1);
     T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
 
-    for (int i = 0; i < num; ++i) {
-      int t_col = outputs[i].numel() / input_row;
+    outputs_cols[0] = 0;
+    for (int i = 0; i < o_num; ++i) {
+      int t_col = outputs[i].numel() / out_row;
       if (sameShape) {
-        if (t_col != output_col_0) sameShape = false;
+        if (t_col != out_col) sameShape = false;
       }
-      input_col += t_col;
-      outputs_cols[i + 1] = input_col;
+      in_col += t_col;
+      outputs_cols[i + 1] = in_col;
       outputs_ptr[i] = outputs[i].data<T>();
     }
 
-    T** outs_gpu =
+    T** dev_out_gpu_data =
         reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
-    const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
 
     // computation
     const int kThreadsPerBlock = 1024;
     int block_cols = kThreadsPerBlock;
-    if (input_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
-      block_cols = ((input_col + 31) >> 5) << 5;
+    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((in_col + 31) >> 5) << 5;
     }
     int block_rows = kThreadsPerBlock / block_cols;
     dim3 block_size = dim3(block_cols, block_rows, 1);
@@ -249,18 +246,19 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
 
     int grid_cols =
-        std::min((input_col + block_cols - 1) / block_cols, max_blocks);
+        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
     int grid_rows =
-        std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
+        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
 
     if (sameShape) {
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
+          input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
     } else {
+      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
       KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
-          input.data<T>(), input_row, input_col, outs_col_gpu,
-          static_cast<int>(outputs_cols.size()), outs_gpu);
+          input.data<T>(), in_row, in_col, dev_outs_col_data,
+          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
     }
   }
 };
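ConcatGradFunctor mirrors this: the incoming gradient, viewed as in_row by in_col, is scattered back into the per-output gradient tensors, and the device copy of outputs_cols is likewise deferred to the non-uniform branch. A sequential NumPy sketch of the equal-width scatter that the second KernelConcatGrad performs:

    import numpy as np

    in_row, fixed_out_col, n = 2, 4, 3
    in_col = n * fixed_out_col
    grad = np.arange(in_row * in_col).reshape(in_row, in_col)
    outs = [np.empty((in_row, fixed_out_col), dtype=int) for _ in range(n)]

    for tid_x in range(in_col):
        split = tid_x // fixed_out_col     # destination gradient tensor
        in_offset = tid_x - split * fixed_out_col
        for tid_y in range(in_row):
            outs[split][tid_y, in_offset] = grad[tid_y, tid_x]

    assert np.array_equal(np.concatenate(outs, axis=1), grad)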

python/paddle/fluid/tests/unittests/test_concat_op.py (23 additions, 7 deletions)

@@ -20,19 +20,35 @@
 class TestConcatOp(OpTest):
     def setUp(self):
         self.op_type = "concat"
-        x0 = np.random.random((2, 1, 4, 5)).astype('float32')
-        x1 = np.random.random((2, 2, 4, 5)).astype('float32')
-        x2 = np.random.random((2, 3, 4, 5)).astype('float32')
-        axis = 1
-        self.inputs = {'X': [('x0', x0), ('x1', x1), ('x2', x2)]}
-        self.attrs = {'axis': axis}
-        self.outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)}
+        self.init_test_data()
+        self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
+        self.attrs = {'axis': self.axis}
+        self.outputs = {
+            'Out': np.concatenate(
+                (self.x0, self.x1, self.x2), axis=self.axis)
+        }
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
         self.check_grad(['x0'], 'Out')
+        self.check_grad(['x1'], 'Out')
+        self.check_grad(['x2'], 'Out')
+
+    def init_test_data(self):
+        self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
+        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.axis = 1
+
+
+class TestConcatOp2(TestConcatOp):
+    def init_test_data(self):
+        self.x0 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.axis = 1
 
 
 if __name__ == '__main__':
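The test refactor moves data construction into an overridable init_test_data hook, so additional shape cases need only redefine the inputs, and the gradient check now covers all three inputs rather than x0 alone. (TestConcatOp2 must inherit TestConcatOp, not OpTest, for its setUp and test methods to run; fixed above.) A further case in the same pattern (hypothetical, not part of this commit) would look like:

    class TestConcatOpAxis0(TestConcatOp):
        def init_test_data(self):
            self.x0 = np.random.random((1, 3, 4, 5)).astype('float32')
            self.x1 = np.random.random((2, 3, 4, 5)).astype('float32')
            self.x2 = np.random.random((4, 3, 4, 5)).astype('float32')
            self.axis = 0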
