@@ -109,18 +109,13 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
         context.Input<Tensor>(framework::GradVarName("Output"));
     Tensor* input_grad =
         context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad_ =
+    Tensor* filter_grad =
         context.Output<Tensor>(framework::GradVarName("Filter"));
 
     // The filter and filter_grad will be reshaped in the calculations,
     // so here we use an assignment operation that avoids modifying
     // the variable in the Scope.
     Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor filter_grad;
-    if (filter_grad_) {
-      filter_grad_->mutable_data<T>(context.GetPlace());
-      filter_grad = *filter_grad_;
-    }
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
@@ -165,43 +160,28 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
                                            filter.numel() / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);
 
-    if (filter_grad_) {
-      filter_grad.Resize(filter_matrix_shape);
-      auto t1 = framework::EigenVector<T>::Flatten(filter_grad);
-      t1.device(context.GetEigenDevice<Place>()) =
-          t1.constant(static_cast<T>(0));
-    }
-
-    if (input_grad) {
-      input_grad->mutable_data<T>(context.GetPlace());
-      auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
-      t2.device(context.GetEigenDevice<Place>()) =
-          t2.constant(static_cast<T>(0));
-    }
-
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
 
     // convolution backward input operator: gemm + col2im
     // convolution backward weight operator: im2col + gemm
     int in_step = input_channels / groups;
     int out_step = output_channels / groups;
-    Tensor in_grad_batch;
-    Tensor in_batch;
-    for (int i = 0; i < batch_size; i++) {
-      Tensor out_grad_batch =
-          output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
-      if (input_grad) {
-        in_grad_batch = input_grad->Slice<T>(i, i + 1).Resize(input_shape);
-      }
-      if (filter_grad_) {
-        in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
-      }
-      for (int g = 0; g < groups; g++) {
-        Tensor out_grad_slice =
-            out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
-        if (input_grad) {
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t = framework::EigenVector<T>::Flatten(*input_grad);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_grad_batch =
+            input_grad->Slice<T>(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
           // gemm
+          Tensor out_grad_slice =
+              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
           Tensor filter_slice =
               filter.Slice<T>(g * out_step, (g + 1) * out_step);
           math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
@@ -213,16 +193,31 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
           col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
                  paddings[1], device_context);
         }
+      }
+    }
 
-      if (filter_grad_) {
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      Tensor filter_grad_ = *filter_grad;
+      filter_grad_.Resize(filter_matrix_shape);
+      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
+      t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+      for (int i = 0; i < batch_size; i++) {
+        Tensor out_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_matrix_shape);
+        Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_shape);
+        for (int g = 0; g < groups; g++) {
           // im2col
+          Tensor out_grad_slice =
+              out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
           Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
           im2col(in_slice, col, strides[0], strides[1], paddings[0],
                  paddings[1], device_context);
 
           // gemm
           Tensor filter_grad_slice =
-              filter_grad.Slice<T>(g * out_step, (g + 1) * out_step);
+              filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
           math::matmul<Place, T>(out_grad_slice, false, col_matrix, true,
                                  T(1.0), &filter_grad_slice, T(1.0),
                                  device_context);
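
For reference, below is a minimal standalone sketch of the decomposition the refactored kernel names in its comments: backward input is a gemm against the transposed filter followed by col2im, and backward weight is im2col followed by a gemm against the transposed column matrix. This is not PaddlePaddle code; the naive im2col, col2im, and gemm helpers are illustrative stand-ins for the kernel's im2col/col2im functors and math::matmul, and the tiny shapes (single image, one group, stride 1, no padding) are arbitrary.

// A minimal sketch of the two backward paths, under the assumptions above.
#include <cstdio>
#include <vector>

// Unfold a C x H x W image into a (C*K*K) x (HO*WO) column matrix.
std::vector<float> im2col(const std::vector<float>& in, int C, int H, int W,
                          int K, int HO, int WO) {
  std::vector<float> col(C * K * K * HO * WO, 0.f);
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < K; ++kh)
      for (int kw = 0; kw < K; ++kw)
        for (int oh = 0; oh < HO; ++oh)
          for (int ow = 0; ow < WO; ++ow)
            col[((c * K + kh) * K + kw) * HO * WO + oh * WO + ow] =
                in[(c * H + oh + kh) * W + ow + kw];
  return col;
}

// Scatter-add a column matrix back into a C x H x W image; overlapping
// windows accumulate, which is why the gradient buffer must start at zero.
void col2im(const std::vector<float>& col, std::vector<float>* in, int C,
            int H, int W, int K, int HO, int WO) {
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < K; ++kh)
      for (int kw = 0; kw < K; ++kw)
        for (int oh = 0; oh < HO; ++oh)
          for (int ow = 0; ow < WO; ++ow)
            (*in)[(c * H + oh + kh) * W + ow + kw] +=
                col[((c * K + kh) * K + kw) * HO * WO + oh * WO + ow];
}

// out = op(A) * op(B): A is m x k and B is k x n after optional transposes,
// mirroring the (matrix, transposed?) argument pairs of math::matmul.
std::vector<float> gemm(const std::vector<float>& A, bool trans_a,
                        const std::vector<float>& B, bool trans_b,
                        int m, int n, int k) {
  std::vector<float> out(m * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p)
        out[i * n + j] += (trans_a ? A[p * m + i] : A[i * k + p]) *
                          (trans_b ? B[j * k + p] : B[p * n + j]);
  return out;
}

int main() {
  const int C = 2, H = 4, W = 4, K = 3, CO = 3;     // channels, image, filter
  const int HO = H - K + 1, WO = W - K + 1;         // stride 1, no padding
  std::vector<float> input(C * H * W, 1.f);
  std::vector<float> filter(CO * C * K * K, 0.5f);  // CO x (C*K*K)
  std::vector<float> d_out(CO * HO * WO, 1.f);      // CO x (HO*WO)

  // Backward weight: im2col, then d_filter = d_out * col^T.
  std::vector<float> col = im2col(input, C, H, W, K, HO, WO);
  std::vector<float> d_filter =
      gemm(d_out, false, col, true, CO, C * K * K, HO * WO);

  // Backward input: d_col = filter^T * d_out, then col2im scatters it back.
  std::vector<float> d_col =
      gemm(filter, true, d_out, false, C * K * K, HO * WO, CO);
  std::vector<float> d_input(C * H * W, 0.f);       // zero-fill, as in the kernel
  col2im(d_col, &d_input, C, H, W, K, HO, WO);

  std::printf("d_filter[0] = %g  d_input[0] = %g\n", d_filter[0], d_input[0]);
  return 0;
}

The zero-fill of d_input before col2im plays the same role as the Eigen t.constant(0) initializations in the kernel: overlapping windows scatter-add into the same input cells, so the accumulator must start from zero.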