
Commit 6c0129a

Refine the GemmConvGrad2DKernel.
1 parent f3669ca commit 6c0129a
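In outline, the change renames the raw output pointer filter_grad_ to filter_grad, and replaces the single batch/group loop that re-tested if (input_grad) and if (filter_grad_) on every iteration with two independently guarded passes: one computing the input gradient (gemm + col2im) and one computing the filter gradient (im2col + gemm). Allocation and zero-initialization of each gradient tensor now happen inside the corresponding branch, so no work is done for a gradient that was not requested.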

File tree

1 file changed: +32 -37 lines

paddle/operators/gemm_conv2d_op.h

Lines changed: 32 additions & 37 deletions
@@ -109,18 +109,13 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad_ =
+    Tensor* filter_grad =
         context.Output<Tensor>(framework::GradVarName("Filter"));
 
     // The filter and filter_grad will be reshaped in the calculations,
     // so here use an assignment operation,
     // that avoids modifying the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor filter_grad;
-    if (filter_grad_) {
-      filter_grad_->mutable_data<T>(context.GetPlace());
-      filter_grad = *filter_grad_;
-    }
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
@@ -165,43 +160,28 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
                                 filter.numel() / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);
 
-    if (filter_grad_) {
-      filter_grad.Resize(filter_matrix_shape);
-      auto t1 = framework::EigenVector<T>::Flatten(filter_grad);
-      t1.device(context.GetEigenDevice<Place>()) =
-          t1.constant(static_cast<T>(0));
-    }
-
-    if (input_grad) {
-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
-      t2.device(context.GetEigenDevice<Place>()) =
-          t2.constant(static_cast<T>(0));
-    }
-
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
 
     // convolution backward input operator: gemm + col2im
     // convolution backward weight operator: im2col + gemm
     int in_step = input_channels / groups;
     int out_step = output_channels / groups;
-    Tensor in_grad_batch;
-    Tensor in_batch;
-    for (int i = 0; i < batch_size; i++) {
-      Tensor out_grad_batch =
-          output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
-      if (input_grad) {
-        in_grad_batch = input_grad->Slice<T>(i, i + 1).Resize(input_shape);
-      }
-      if (filter_grad_) {
-        in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
-      }
-      for (int g = 0; g < groups; g++) {
-        Tensor out_grad_slice =
-            out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
-        if (input_grad) {
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t = framework::EigenVector<T>::Flatten(*input_grad);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_grad_batch =
+            input_grad->Slice<T>(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
           // gemm
+          Tensor out_grad_slice =
+              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
           Tensor filter_slice =
               filter.Slice<T>(g * out_step, (g + 1) * out_step);
           math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
@@ -213,16 +193,31 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
           col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
                  paddings[1], device_context);
         }
+      }
+    }
 
-    if (filter_grad_) {
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      Tensor filter_grad_ = *filter_grad;
+      filter_grad_.Resize(filter_matrix_shape);
+      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
          // im2col
+          Tensor out_grad_slice =
+              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
          Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
          im2col(in_slice, col, strides[0], strides[1], paddings[0],
                 paddings[1], device_context);
 
          // gemm
          Tensor filter_grad_slice =
-              filter_grad.Slice<T>(g * out_step, (g + 1) * out_step);
+              filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
          math::matmul<Place, T>(out_grad_slice, false, col_matrix, true,
                                 T(1.0), &filter_grad_slice, T(1.0),
                                 device_context);
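The comment pair in the kernel states the algorithm: the backward-input path is a gemm followed by col2im, and the backward-weight path is im2col followed by a gemm. Below is a minimal, self-contained sketch of those two decompositions, not PaddlePaddle code: it assumes one input channel, one output channel, one group, stride 1, no padding, and made-up sizes H, W, and K.

// Sketch of "gemm + col2im" (input grad) and "im2col + gemm" (filter grad)
// for one channel, one group, stride 1, no padding. Sizes H, W, K are
// illustrative assumptions, not values from the kernel.
#include <cstdio>
#include <vector>

int main() {
  const int H = 4, W = 4, K = 3;             // input and filter sizes
  const int OH = H - K + 1, OW = W - K + 1;  // "valid" output size

  std::vector<float> in(H * W), filter(K * K), dout(OH * OW, 1.0f);
  for (int i = 0; i < H * W; ++i) in[i] = 0.1f * i;
  for (int i = 0; i < K * K; ++i) filter[i] = 0.01f * i;

  // im2col: unfold every input patch into a column of a
  // (K*K) x (OH*OW) matrix.
  std::vector<float> col(K * K * OH * OW);
  for (int kh = 0; kh < K; ++kh)
    for (int kw = 0; kw < K; ++kw)
      for (int oh = 0; oh < OH; ++oh)
        for (int ow = 0; ow < OW; ++ow)
          col[(kh * K + kw) * OH * OW + oh * OW + ow] =
              in[(oh + kh) * W + (ow + kw)];

  // Filter gradient (im2col + gemm): dFilter = dOut * col^T.
  std::vector<float> dfilter(K * K, 0.0f);
  for (int p = 0; p < K * K; ++p)
    for (int o = 0; o < OH * OW; ++o)
      dfilter[p] += dout[o] * col[p * OH * OW + o];

  // Input gradient, step 1 (gemm): dCol = filter^T * dOut; a rank-1
  // product here because there is a single output channel.
  std::vector<float> dcol(K * K * OH * OW);
  for (int p = 0; p < K * K; ++p)
    for (int o = 0; o < OH * OW; ++o)
      dcol[p * OH * OW + o] = filter[p] * dout[o];

  // Input gradient, step 2 (col2im): scatter-add each column entry
  // back onto the input position it was unfolded from.
  std::vector<float> din(H * W, 0.0f);
  for (int kh = 0; kh < K; ++kh)
    for (int kw = 0; kw < K; ++kw)
      for (int oh = 0; oh < OH; ++oh)
        for (int ow = 0; ow < OW; ++ow)
          din[(oh + kh) * W + (ow + kw)] +=
              dcol[(kh * K + kw) * OH * OW + oh * OW + ow];

  std::printf("dFilter[0] = %.3f  dInput[0] = %.3f\n", dfilter[0], din[0]);
  return 0;
}

With several output channels, dout becomes a matrix and both products are full gemms; accumulating them per batch and group is why the kernel zeroes each gradient once up front and then calls math::matmul on filter_grad_slice with an output scale of T(1.0).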
