@@ -570,9 +570,7 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
570570
571571void  EinsumInferMeta (const  std::vector<const  MetaTensor*>& inputs,
572572                     const  std::string& equation,
573-                      MetaTensor* out,
574-                      std::vector<MetaTensor*> inner_cache,
575-                      std::vector<MetaTensor*> xshape) {
573+                      MetaTensor* out) {
576574  //  collect the following informations to prepare einsum.
577575  LabelMap labelshape (0 );
578576  LabelMap labeltype (LabelType::Reduction);
@@ -609,6 +607,14 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
609607  VLOG (3 ) << " Label Shape is : "   << label_to_string (all_labels, labelshape);
610608  out->set_dims (make_ddim (output_dims));
611609  out->set_dtype (inputs[0 ]->dtype ());
610+ }
611+ 
612+ void  EinsumRawInferMeta (const  std::vector<const  MetaTensor*>& inputs,
613+                         const  std::string& equation,
614+                         MetaTensor* out,
615+                         std::vector<MetaTensor*> inner_cache,
616+                         std::vector<MetaTensor*> xshape) {
617+   EinsumInferMeta (inputs, equation, out);
612618  for  (size_t  i = 0 ; i < xshape.size (); ++i) {
613619    if  (xshape[i] != nullptr ) {
614620      xshape[i]->set_dims (inputs[i]->dims ());
@@ -2448,8 +2454,7 @@ void SplitInferMeta(const MetaTensor& x,
24482454
24492455void  SqueezeInferMeta (const  MetaTensor& x,
24502456                      const  std::vector<int >& axes,
2451-                       MetaTensor* out,
2452-                       MetaTensor* xshape) {
2457+                       MetaTensor* out) {
24532458  const  auto & x_dims = x.dims ();
24542459  //  Check input tensor dims (<6) Eigen limit.
24552460  PADDLE_ENFORCE_LE (x_dims.size (),
@@ -2469,15 +2474,25 @@ void SqueezeInferMeta(const MetaTensor& x,
24692474    out->share_lod (x);
24702475  }
24712476
2477+   out->set_dtype (x.dtype ());
2478+ }
2479+ 
2480+ void  SqueezeWithXShapeInferMeta (const  MetaTensor& x,
2481+                                 const  std::vector<int >& axes,
2482+                                 MetaTensor* out,
2483+                                 MetaTensor* xshape) {
2484+   SqueezeInferMeta (x, axes, out);
2485+   const  auto & x_dims = x.dims ();
24722486  std::vector<int64_t > xshape_dims (x_dims.size () + 1 );
24732487  xshape_dims[0 ] = 0 ;
24742488  for  (int  i = 0 ; i < x_dims.size (); ++i) {
24752489    xshape_dims[i + 1 ] = x_dims[i];
24762490  }
2477-   xshape->set_dims (phi::make_ddim (xshape_dims));
2478-   xshape->share_lod (x);
2479-   xshape->set_dtype (x.dtype ());
2480-   out->set_dtype (x.dtype ());
2491+   if  (xshape) {
2492+     xshape->set_dims (phi::make_ddim (xshape_dims));
2493+     xshape->share_lod (x);
2494+     xshape->set_dtype (x.dtype ());
2495+   }
24812496}
24822497
24832498void  StridedSliceRawInferMeta (const  MetaTensor& x,
@@ -3310,7 +3325,6 @@ void UniqueRawInferMeta(const MetaTensor& x,
33103325void  UnsqueezeInferMeta (const  MetaTensor& x,
33113326                        const  IntArray& axes,
33123327                        MetaTensor* out,
3313-                         MetaTensor* xshape,
33143328                        MetaConfig config) {
33153329  const  auto & x_dims = x.dims ();
33163330  //  Validity Check: input tensor dims (<6).
@@ -3339,14 +3353,22 @@ void UnsqueezeInferMeta(const MetaTensor& x,
33393353    }
33403354    out->set_dtype (x.dtype ());
33413355  }
3342-   if  (xshape) {
3343-     //  set xshape dims.
3344-     std::vector<int64_t > xshape_dims (x_dims.size () + 1 );
3345-     xshape_dims[0 ] = 0 ;
3346-     for  (int  i = 0 ; i < x_dims.size (); ++i) {
3347-       xshape_dims[i + 1 ] = x_dims[i];
3348-     }
3356+ }
33493357
3358+ void  UnsqueezeWithXShapeInferMeta (const  MetaTensor& x,
3359+                                   const  IntArray& axes,
3360+                                   MetaTensor* out,
3361+                                   MetaTensor* xshape,
3362+                                   MetaConfig config) {
3363+   const  auto & x_dims = x.dims ();
3364+   UnsqueezeInferMeta (x, axes, out, config);
3365+   //  set xshape dims.
3366+   std::vector<int64_t > xshape_dims (x_dims.size () + 1 );
3367+   xshape_dims[0 ] = 0 ;
3368+   for  (int  i = 0 ; i < x_dims.size (); ++i) {
3369+     xshape_dims[i + 1 ] = x_dims[i];
3370+   }
3371+   if  (xshape) {
33503372    xshape->set_dims (phi::make_ddim (xshape_dims));
33513373    xshape->share_lod (x);
33523374    xshape->set_dtype (x.dtype ());
0 commit comments