@@ -24,6 +24,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+
 static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
   if (x_dim.size() > 1) {
     return x_dim;
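Note on the helpers above: `RowMatrixFromVector` (and `ColumnMatrixFromVector`, used further down) only change anything when an operand is 1-D. Below is a minimal standalone sketch of that shape promotion, using a plain `std::vector<int64_t>` in place of `framework::DDim`; the column-side behaviour is inferred from how the descriptors are built later in this diff, so treat it as an assumption rather than Paddle's exact implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for framework::DDim: just a list of dimension sizes.
using Dims = std::vector<int64_t>;

// Promote a 1-D shape [N] to a row matrix [1, N]; higher-rank shapes pass through.
Dims RowMatrixFromVector(const Dims &d) {
  return d.size() > 1 ? d : Dims{1, d[0]};
}

// Assumed counterpart: promote a 1-D shape [N] to a column matrix [N, 1].
Dims ColumnMatrixFromVector(const Dims &d) {
  return d.size() > 1 ? d : Dims{d[0], 1};
}

int main() {
  assert((RowMatrixFromVector({8}) == Dims{1, 8}));
  assert((ColumnMatrixFromVector({8}) == Dims{8, 1}));
  assert((RowMatrixFromVector({4, 8}) == Dims{4, 8}));  // unchanged
  return 0;
}
```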
@@ -97,6 +99,86 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
   ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
 }
 
+template <typename T, typename FCT>
+static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
+                              bool trans_x, bool trans_y,
+                              const paddle::framework::ExecutionContext &ctx) {
+  const auto &x_dims = x->dims();
+  const auto &y_dims = y->dims();
+  auto &dev_ctx =
+      ctx.template device_context<paddle::platform::XPUDeviceContext>();
+
+  auto mat_dim_a =
+      math::CreateMatrixDescriptor(RowMatrixFromVector(x_dims), 0, trans_x);
+  auto mat_dim_b =
+      math::CreateMatrixDescriptor(ColumnMatrixFromVector(y_dims), 0, trans_y);
+
+  if (x_dims.size() == 3 && y_dims.size() <= 2) {
+    // if transpose_X is true, the transpose costs much time
+    if (!trans_x) {
+      mat_dim_a.height_ *= mat_dim_a.batch_size_;
+      mat_dim_a.batch_size_ = 0;
+    } else {
+      mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
+      mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
+    }
+  }
+  PADDLE_ENFORCE_EQ(
+      mat_dim_a.width_, mat_dim_b.height_,
+      platform::errors::InvalidArgument("Shape mistake in matmul_op, the "
+                                        "first tensor width must be same as "
+                                        "second tensor height, but received "
+                                        "width:%d, height:%d",
+                                        mat_dim_a.width_, mat_dim_b.height_));
+  PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
+                    platform::errors::InvalidArgument(
+                        "Shape mistake in matmul_op, the two input "
+                        "tensor batch_size must be same, but received first "
+                        "tensor batch_size:%d, second "
+                        "tensor batch_size:%d",
+                        mat_dim_a.batch_size_, mat_dim_b.batch_size_));
+
+  T alpha = static_cast<T>(ctx.Attr<float>("alpha"));
+
+  float *data_c = out->data<T>();
+  int m = mat_dim_a.height_;
+  int n = mat_dim_b.width_;
+  int k = mat_dim_a.width_;
+  int ldx = mat_dim_a.trans_ ? m : k;
+  int ldy = mat_dim_b.trans_ ? k : n;
+  int ldout = n;
+  int batch_size = mat_dim_a.batch_size_;
+
+  if (batch_size == 0) {
+    int r = xpu::fc_fusion<float, float, float, FCT>(
+        dev_ctx.x_context(), x->data<T>(), y->data<T>(), data_c, m, n, k,
+        mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy,
+        ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU fc_fusion kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
+  } else {
+    // batch matmul
+    int x_stride = mat_dim_a.stride_;
+    int y_stride = mat_dim_b.stride_;
+    int out_stride = m * n;
+    for (int i = 0; i < batch_size; ++i) {
+      const float *x_data = x->data<T>() + x_stride * i;
+      const float *y_data = y->data<T>() + y_stride * i;
+      float *out_data = data_c + out_stride * i;
+      int r = xpu::fc_fusion<float, float, float, FCT>(
+          dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
+          mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
+          ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
+      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                        platform::errors::External(
+                            "XPU fc_fusion kernel return wrong value[%d %s]", r,
+                            XPUAPIErrorMsg[r]));
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 class MatMulXPUKernel : public framework::OpKernel<T> {
  public:
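Before the refactored kernels below, a quick illustration of what `MatMulXPUFunction` does with a rank-3 X against a rank-2 Y: the `!trans_x` branch folds the batch dimension into the row count so a single GEMM covers the whole batch. The sketch uses a hypothetical `MatDim` struct that mirrors only the `height_`/`width_`/`batch_size_`/`trans_` fields; it is not `math::CreateMatrixDescriptor`.

```cpp
#include <cstdio>

// Hypothetical stand-in for the descriptor produced by math::CreateMatrixDescriptor.
struct MatDim {
  int batch_size_;
  int height_;
  int width_;
  bool trans_;
};

int main() {
  // X: [batch=4, M=16, K=32] (not transposed), Y: [K=32, N=64].
  MatDim mat_dim_a{4, 16, 32, false};
  MatDim mat_dim_b{0, 32, 64, false};

  // Same folding as the !trans_x branch above: treat X as one [batch*M, K] matrix.
  mat_dim_a.height_ *= mat_dim_a.batch_size_;
  mat_dim_a.batch_size_ = 0;

  int m = mat_dim_a.height_;           // 4 * 16 = 64 rows after folding
  int n = mat_dim_b.width_;
  int k = mat_dim_a.width_;
  int ldx = mat_dim_a.trans_ ? m : k;  // row-major leading dimension of X
  int ldy = mat_dim_b.trans_ ? k : n;  // row-major leading dimension of Y
  std::printf("m=%d n=%d k=%d ldx=%d ldy=%d ldout=%d\n", m, n, k, ldx, ldy, n);
  return 0;
}
```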
@@ -105,78 +187,12 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
     auto *y = context.Input<framework::Tensor>("Y");
     auto *out = context.Output<framework::Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
-
-    auto mat_dim_a = math::CreateMatrixDescriptor(
-        RowMatrixFromVector(x->dims()), 0, context.Attr<bool>("transpose_X"));
-    auto mat_dim_b =
-        math::CreateMatrixDescriptor(ColumnMatrixFromVector(y->dims()), 0,
-                                     context.Attr<bool>("transpose_Y"));
-
-    const auto &x_dims = x->dims();
-    const auto &y_dims = y->dims();
-    if (x_dims.size() == 3 && y_dims.size() <= 2) {
-      // if transpose_X is true, the transpose cost much time
-      if (!context.Attr<bool>("transpose_X")) {
-        mat_dim_a.height_ *= mat_dim_a.batch_size_;
-        mat_dim_a.batch_size_ = 0;
-      } else {
-        mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
-        mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
-      }
-    }
-
-    PADDLE_ENFORCE_EQ(
-        mat_dim_a.width_, mat_dim_b.height_,
-        platform::errors::InvalidArgument("Shape mistake in matmul_op, the "
-                                          "first tensor width must be same as "
-                                          "second tensor height, but received "
-                                          "width:%d, height:%d",
-                                          mat_dim_a.width_, mat_dim_b.height_));
-    PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
-                      platform::errors::InvalidArgument(
-                          "Shape mistake in matmul_op, the two input"
-                          "tensor batch_size must be same, but received first "
-                          "tensor batch_size:%d, second "
-                          "tensor batch_size:%d",
-                          mat_dim_a.batch_size_, mat_dim_b.batch_size_));
-    T alpha = static_cast<T>(context.Attr<float>("alpha"));
-
-    auto &dev_ctx = context.template device_context<DeviceContext>();
-    float *data_c = out->data<T>();
-    int m = mat_dim_a.height_;
-    int n = mat_dim_b.width_;
-    int k = mat_dim_a.width_;
-    int ldx = mat_dim_a.trans_ ? m : k;
-    int ldy = mat_dim_b.trans_ ? k : n;
-    int ldout = n;
-    int batch_size = mat_dim_a.batch_size_;
-    if (batch_size == 0 || batch_size == 1) {
-      int r = xpu::fc_fusion<float, float, float, int16_t>(
-          dev_ctx.x_context(), x->data<T>(), y->data<T>(), data_c, m, n, k,
-          mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-          ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
-      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                        platform::errors::External(
-                            "XPU fc_fusion kernel return wrong value[%d %s]", r,
-                            XPUAPIErrorMsg[r]));
+    bool trans_x = context.Attr<bool>("transpose_X");
+    bool trans_y = context.Attr<bool>("transpose_Y");
+    if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) {
+      MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, context);
     } else {
-      // batch matmul
-      int x_stride = mat_dim_a.stride_;
-      int y_stride = mat_dim_b.stride_;
-      int out_stride = m * n;
-      for (int i = 0; i < batch_size; ++i) {
-        const float *x_data = x->data<T>() + x_stride * i;
-        const float *y_data = y->data<T>() + y_stride * i;
-        float *out_data = data_c + out_stride * i;
-        int r = xpu::fc_fusion<float, float, float, int16_t>(
-            dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
-            mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-            ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
-        PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                          platform::errors::External(
-                              "XPU fc_fusion kernel return wrong value[%d %s]",
-                              r, XPUAPIErrorMsg[r]));
-      }
+      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
     }
   }
 };
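The new dispatch above chooses the `xpu::fc_fusion` accumulation type at run time from an environment variable. A self-contained sketch of the same pattern follows; the `RunMatMul` stub is illustrative only, while the `XPU_PADDLE_MAT_MUL_FCINT32` variable name comes from the diff itself.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <typeinfo>

// Illustrative stub standing in for MatMulXPUFunction<T, FCT>.
template <typename FCT>
void RunMatMul() {
  // typeid(...).name() is implementation-defined but good enough for a demo.
  std::printf("dispatching with FC accumulation type: %s\n", typeid(FCT).name());
}

int main() {
  // Same selection logic as the kernel above: the env var flips int16_t -> int32_t.
  if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) {
    RunMatMul<int32_t>();
  } else {
    RunMatMul<int16_t>();
  }
  return 0;
}
```

Running it as `XPU_PADDLE_MAT_MUL_FCINT32=1 ./demo` takes the `int32_t` branch; leaving the variable unset falls back to `int16_t`, matching the kernel's default path.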
@@ -244,75 +260,10 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
               const framework::Tensor &b, bool trans_b,
               framework::Tensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
-    auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
-    const auto &a_dims = a.dims();
-    const auto &b_dims = b.dims();
-    if (a_dims.size() == 3 && b_dims.size() <= 2) {
-      // if transpose_X is true, the transpose cost much time
-      if (!context.Attr<bool>("transpose_X")) {
-        mat_dim_a.height_ *= mat_dim_a.batch_size_;
-        mat_dim_a.batch_size_ = 0;
-      } else {
-        mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
-        mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
-      }
-    }
-
-    PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_,
-                      platform::errors::InvalidArgument(
-                          "Shape mistake in matmul_grad_op, the "
-                          "first tensor width must be same as second tensor "
-                          "height, but received "
-                          "width:%d, height:%d",
-                          mat_dim_a.width_, mat_dim_b.height_));
-    PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
-                      platform::errors::InvalidArgument(
-                          "Shape mistake in matmul_grad_op, the two input"
-                          "tensor batch_size must be same, but received first "
-                          "tensor batch_size:%d, second "
-                          "tensor batch_size:%d",
-                          mat_dim_a.batch_size_, mat_dim_b.batch_size_));
-
-    T alpha = static_cast<T>(context.Attr<float>("alpha"));
-
-    auto &dev_ctx = context.template device_context<DeviceContext>();
-    float *data_c = out->data<T>();
-
-    int m = mat_dim_a.height_;
-    int n = mat_dim_b.width_;
-    int k = mat_dim_a.width_;
-    int ldx = mat_dim_a.trans_ ? m : k;
-    int ldy = mat_dim_b.trans_ ? k : n;
-    int ldout = n;
-    int batch_size = mat_dim_a.batch_size_;
-    if (batch_size == 0 || batch_size == 1) {
-      int r = xpu::fc_fusion<float, float, float, int16_t>(
-          dev_ctx.x_context(), a.data<T>(), b.data<T>(), data_c, m, n, k,
-          mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-          ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
-      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                        platform::errors::External(
-                            "XPU fc_fusion kernel return wrong value[%d %s]", r,
-                            XPUAPIErrorMsg[r]));
+    if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) {
+      MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, context);
     } else {
-      // batch matmul
-      int x_stride = mat_dim_a.stride_;
-      int y_stride = mat_dim_b.stride_;
-      int out_stride = m * n;
-      for (int i = 0; i < batch_size; ++i) {
-        const float *x_data = a.data<T>() + x_stride * i;
-        const float *y_data = b.data<T>() + y_stride * i;
-        float *out_data = data_c + out_stride * i;
-        int r = xpu::fc_fusion<float, float, float, int16_t>(
-            dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
-            mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-            ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
-        PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                          platform::errors::External(
-                              "XPU fc_fusion kernel return wrong value[%d %s]",
-                              r, XPUAPIErrorMsg[r]));
-      }
+      MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
     }
   }
 
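For the batched branch shared by the forward and grad paths, each matrix in the batch is addressed by advancing raw pointers by `stride_` elements rather than issuing one batched call. Below is a simplified host-only sketch of that offset arithmetic, with a dummy per-matrix routine standing in for `xpu::fc_fusion`, so the stride bookkeeping can be checked without an XPU.

```cpp
#include <cstdio>
#include <vector>

// Dummy per-matrix op standing in for one xpu::fc_fusion call:
// fills the output tile with alpha * (x[0] + y[0]).
static void FakeGemm(const float *x, const float *y, float *out, int m, int n,
                     float alpha) {
  for (int i = 0; i < m * n; ++i) out[i] = alpha * (x[0] + y[0]);
}

int main() {
  const int batch_size = 3, m = 2, n = 2, k = 2;
  const int x_stride = m * k, y_stride = k * n, out_stride = m * n;
  std::vector<float> x(batch_size * x_stride, 1.0f);
  std::vector<float> y(batch_size * y_stride, 2.0f);
  std::vector<float> out(batch_size * out_stride, 0.0f);

  // Same loop shape as the batch branch above: advance each pointer by its stride.
  for (int i = 0; i < batch_size; ++i) {
    const float *x_data = x.data() + x_stride * i;
    const float *y_data = y.data() + y_stride * i;
    float *out_data = out.data() + out_stride * i;
    FakeGemm(x_data, y_data, out_data, m, n, /*alpha=*/1.0f);
  }
  std::printf("out[0]=%.1f out[last]=%.1f\n", out.front(), out.back());
  return 0;
}
```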