@@ -135,7 +135,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
135135
136136 // col_matrix = filter * input_batch
137137 // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
138-      blas.MatMul(filter, true, input_batch, false, &col_matrix);
138+      blas.MatMul(filter, true, input_batch, false, static_cast<T>(1.0),
139+                  &col_matrix, static_cast<T>(0.0));
139140
140141 if (data_dim == 2U ) {
141142 // col2im: col_matrix -> dy
@@ -267,7 +268,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
267268 // or
268269 // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
269270 // d, h, w)
270-        blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
271+        blas.MatMul(filter, false, col_matrix, false, static_cast<T>(1.0),
272+                    &input_grad_batch, static_cast<T>(0.0));
271273 }
272274 if (filter_grad) {
273275 // input batch
@@ -277,7 +279,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
277279 // or
278280 // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
279281 // k_h * k_w)
280-        blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
282+        blas.MatMul(in_batch, false, col_matrix, true, static_cast<T>(1.0),
283+                    &filter_grad_, static_cast<T>(1.0));
281284 }
282285 }
283286 }
0 commit comments