@@ -25,7 +25,7 @@ namespace operators {
2525 * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
2626 * original x_dim is returned.
2727 */
28- static framework::DDim RowMatrixFromVector (const framework::DDim& x_dim) {
28+ static framework::DDim RowMatrixFromVector (const framework::DDim & x_dim) {
2929 if (x_dim.size () > 1 ) {
3030 return x_dim;
3131 }
@@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
3636 * Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
3737 * original y_dim is returned.
3838 */
39- static framework::DDim ColumnMatrixFromVector (const framework::DDim& y_dim) {
39+ static framework::DDim ColumnMatrixFromVector (const framework::DDim & y_dim) {
4040 if (y_dim.size () > 1 ) {
4141 return y_dim;
4242 }
@@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
4646template <typename DeviceContext, typename T>
4747class MatMulKernel : public framework ::OpKernel<T> {
4848 public:
49- void Compute (const framework::ExecutionContext& context) const override {
50- auto & x =
49+ void Compute (const framework::ExecutionContext & context) const override {
50+ auto & x =
5151 detail::Ref (context.Input <framework::Tensor>(" X" ), " Cannot find X" );
52- auto & y =
52+ auto & y =
5353 detail::Ref (context.Input <framework::Tensor>(" Y" ), " Cannot find Y" );
54- auto * out = context.Output <framework::Tensor>(" Out" );
54+ auto * out = context.Output <framework::Tensor>(" Out" );
5555 out->mutable_data <T>(context.GetPlace ());
5656
5757 auto blas = math::GetBlas<DeviceContext, T>(context);
@@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> {
6565
6666// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
6767// Identity op if the tensor is not of rank 3.
68- static framework::Tensor FoldInitDims (const framework::Tensor& input) {
68+ static framework::Tensor FoldInitDims (const framework::Tensor & input) {
6969 auto output = input;
7070 auto in_dims = input.dims ();
7171 if (in_dims.size () == 3 ) {
@@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
7878// (Warning: This requires transposing data and writes into new memory.)
7979// Identity op if the tensor is not of rank 3.
8080template <typename DeviceContext, typename T>
81- static framework::Tensor FoldHeadAndLastDims (const DeviceContext& context,
82- const framework::Tensor& input) {
81+ static framework::Tensor FoldHeadAndLastDims (const DeviceContext & context,
82+ const framework::Tensor & input) {
8383 auto in_dims = input.dims ();
8484 if (in_dims.size () != 3 ) {
8585 return input;
@@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
102102 * If transposed, `H,W` will be swapped.
103103 */
104104static void ReshapeTensorIntoMatrixSequence (
105- framework::Tensor* x, const math::MatDescriptor& descriptor) {
105+ framework::Tensor * x, const math::MatDescriptor & descriptor) {
106106 int64_t h, w;
107107 h = descriptor.height_ ;
108108 w = descriptor.width_ ;
@@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence(
130130 * If either `X` or `Y` has batch size BatchSize, the output will have the
131131 * same batch size.
132132 */
133- static void ReshapeXYOutIntoMatrixSequence (framework::Tensor* x,
134- framework::Tensor* y,
135- framework::Tensor* out, bool trans_x,
133+ static void ReshapeXYOutIntoMatrixSequence (framework::Tensor * x,
134+ framework::Tensor * y,
135+ framework::Tensor * out, bool trans_x,
136136 bool trans_y) {
137137 auto x_dim = RowMatrixFromVector (x->dims ());
138138 auto y_dim = ColumnMatrixFromVector (y->dims ());
@@ -177,29 +177,29 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
177177template <typename DeviceContext, typename T>
178178class MatMulGradKernel : public framework ::OpKernel<T> {
179179 public:
180- void MatMul (const framework::ExecutionContext& context,
181- const framework::Tensor& a, bool trans_a,
182- const framework::Tensor& b, bool trans_b,
183- framework::Tensor* out) const {
180+ void MatMul (const framework::ExecutionContext & context,
181+ const framework::Tensor & a, bool trans_a,
182+ const framework::Tensor & b, bool trans_b,
183+ framework::Tensor * out) const {
184184 out->mutable_data <T>(context.GetPlace ());
185185 auto blas = math::GetBlas<DeviceContext, T>(context);
186186 auto mat_dim_a = math::CreateMatrixDescriptor (a.dims (), 0 , trans_a);
187187 auto mat_dim_b = math::CreateMatrixDescriptor (b.dims (), 0 , trans_b);
188188 blas.MatMul (a, mat_dim_a, b, mat_dim_b, T (1 ), out, T (0 ));
189189 }
190190
191- void CalcInputGrad (const framework::ExecutionContext& context,
192- const framework::Tensor& a, bool trans_a,
193- bool is_fold_init_dims_a, const framework::Tensor& b,
191+ void CalcInputGrad (const framework::ExecutionContext & context,
192+ const framework::Tensor & a, bool trans_a,
193+ bool is_fold_init_dims_a, const framework::Tensor & b,
194194 bool trans_b, bool is_fold_init_dims_b,
195- framework::Tensor* out) const {
195+ framework::Tensor * out) const {
196196 if (out == nullptr ) return ;
197197 bool need_combine = (a.dims ().size () == 3 || b.dims ().size () == 3 ) &&
198198 out->dims ().size () == 2 ;
199199 if (!need_combine) {
200200 MatMul (context, a, trans_a, b, trans_b, out);
201201 } else {
202- auto & ctx = context.template device_context <DeviceContext>();
202+ auto & ctx = context.template device_context <DeviceContext>();
203203 MatMul (context, is_fold_init_dims_a
204204 ? FoldInitDims (a)
205205 : FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
@@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
210210 }
211211 }
212212
213- void Compute (const framework::ExecutionContext& context) const override {
213+ void Compute (const framework::ExecutionContext & context) const override {
214214 auto x = *context.Input <framework::Tensor>(" X" );
215215 auto y = *context.Input <framework::Tensor>(" Y" );
216216 auto dout =
217217 *context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
218- auto * dx = context.Output <framework::Tensor>(framework::GradVarName (" X" ));
219- auto * dy = context.Output <framework::Tensor>(framework::GradVarName (" Y" ));
218+ auto * dx = context.Output <framework::Tensor>(framework::GradVarName (" X" ));
219+ auto * dy = context.Output <framework::Tensor>(framework::GradVarName (" Y" ));
220220 bool transpose_x = context.Attr <bool >(" transpose_X" );
221221 bool transpose_y = context.Attr <bool >(" transpose_Y" );
222222
@@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel {
269269 using framework::OperatorWithKernel::OperatorWithKernel;
270270
271271 protected:
272- void InferShape (framework::InferShapeContext* context) const override {
272+ void InferShape (framework::InferShapeContext * context) const override {
273273 PADDLE_ENFORCE (context->HasInput (" X" ),
274274 " Input(X) of MatMulOp should not be null." );
275275 PADDLE_ENFORCE (context->HasInput (" Y" ),
@@ -375,7 +375,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
375375 using framework::OperatorWithKernel::OperatorWithKernel;
376376
377377 protected:
378- void InferShape (framework::InferShapeContext* context) const override {
378+ void InferShape (framework::InferShapeContext * context) const override {
379379 PADDLE_ENFORCE (context->HasInput (" X" ), " Input(X) should not be null" );
380380 PADDLE_ENFORCE (context->HasInput (" Y" ), " Input(Y) should not be null" );
381381 PADDLE_ENFORCE (context->HasInput (framework::GradVarName (" Out" )),
@@ -401,7 +401,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
401401
402402 protected:
403403 std::unique_ptr<framework::OpDesc> Apply () const override {
404- auto * retv = new framework::OpDesc ();
404+ auto * retv = new framework::OpDesc ();
405405 retv->SetType (" matmul_grad" );
406406 retv->SetInput (" X" , Input (" X" ));
407407 retv->SetInput (" Y" , Input (" Y" ));
@@ -420,15 +420,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
420420 ops::MatMulOpGradMaker);
421421REGISTER_OPERATOR (matmul_grad, ops::MatMulOpGrad);
422422REGISTER_OP_CPU_KERNEL (
423- matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float >);
423+ matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float >,
424+ ops::MatMulKernel<paddle::platform::CPUDeviceContext, double >,
425+ ops::MatMulKernel<paddle::platform::CPUDeviceContext,
426+ paddle::platform::float16>);
424427REGISTER_OP_CPU_KERNEL (
425428 matmul_grad,
426- ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float >);
429+ ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float >,
430+ ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double >,
431+ ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
432+ paddle::platform::float16>);
427433
428434#ifdef PADDLE_WITH_CUDA
429435REGISTER_OP_CUDA_KERNEL (
430- matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float >);
436+ matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float >,
437+ ops::MatMulKernel<paddle::platform::CUDADeviceContext, double >,
438+ ops::MatMulKernel<paddle::platform::CUDADeviceContext,
439+ paddle::platform::float16>);
431440REGISTER_OP_CUDA_KERNEL (
432441 matmul_grad,
433- ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float >);
442+ ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float >,
443+ ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double >,
444+ ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
445+ paddle::platform::float16>);
434446#endif
0 commit comments