@@ -120,12 +120,10 @@ static void SetOutMemDescWithUnsqueeze2FuseSupport(
120120 const std::vector<int64_t >& op_tz = out_md.dims ();
121121 std::vector<int64_t > unsqueezed_op_tz (
122122 op_tz.size () + fused_unsqueeze2_axes.size (), 0 );
123-
124123 for (const auto & axis : fused_unsqueeze2_axes) {
125124 int positive_axis = axis < 0 ? unsqueezed_op_tz.size () + axis : axis;
126125 unsqueezed_op_tz[positive_axis] = 1 ;
127126 }
128-
129127 int j = 0 ;
130128 for (size_t i = 0 ; i < unsqueezed_op_tz.size (); ++i) {
131129 if (unsqueezed_op_tz[i] == 0 ) {
@@ -143,20 +141,17 @@ static void SetOutMemDescWithReshape2FuseSupport(
143141 std::vector<int64_t > fused_reshape2_shape (
144142 ctx.Attr <std::vector<int >>(" fused_reshape2_shape" ).begin (),
145143 ctx.Attr <std::vector<int >>(" fused_reshape2_shape" ).end ());
146-
147144 const int out_shape_numel = out->numel ();
148145 const int new_shape_numel = std::accumulate (fused_reshape2_shape.begin (),
149146 fused_reshape2_shape.end (),
150147 1 ,
151148 std::multiplies<int64_t >());
152-
153149 for (size_t i = 0 ; i < fused_reshape2_shape.size (); ++i) {
154150 if (fused_reshape2_shape[i] == -1 ) {
155151 fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
156152 break ;
157153 }
158154 }
159-
160155 out->set_mem_desc (out_md.reshape (fused_reshape2_shape));
161156 out->Resize (phi::make_ddim (fused_reshape2_shape));
162157}
@@ -169,11 +164,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport(
169164 SetOutMemDescWithUnsqueeze2FuseSupport (ctx, out, out_md);
170165 } else if (ctx.HasAttr (" fused_reshape2_shape" )) {
171166 SetOutMemDescWithReshape2FuseSupport (ctx, out, out_md);
167+ } else if (ctx.HasAttr (" fused_squeeze2_axes" )) {
168+ out->set_mem_desc (out_md);
169+ out->Resize (phi::make_ddim (out_md.dims ()));
172170 } else {
173171 out->set_mem_desc (out_md);
174172 }
175173}
176174
175+ static void SetInMemDescWithSqueeze2FuseSupport (
176+ const framework::ExecutionContext& ctx,
177+ phi::DenseTensor* in,
178+ const dnnl::memory::desc& in_md) {
179+ const std::vector<int > fused_squeeze2_axes =
180+ ctx.Attr <std::vector<int >>(" fused_squeeze2_axes" );
181+ const std::set<int64_t > squeeze2_axes_set (fused_squeeze2_axes.begin (),
182+ fused_squeeze2_axes.end ());
183+ const std::vector<int64_t >& x_vec_dims = in_md.dims ();
184+ std::vector<int64_t > squeezed_op_tz (
185+ x_vec_dims.size () - fused_squeeze2_axes.size (), 0 );
186+
187+ int j = 0 ;
188+ for (size_t i = 0 ; i < x_vec_dims.size (); ++i) {
189+ if (squeeze2_axes_set.count (i) ||
190+ squeeze2_axes_set.count (i - x_vec_dims.size ())) {
191+ PADDLE_ENFORCE_EQ (
192+ x_vec_dims[i],
193+ 1 ,
194+ platform::errors::InvalidArgument (
195+ " Squeeze2 input '%d' dim should be equal to one, but get '%d'." ,
196+ i,
197+ x_vec_dims[i]));
198+ continue ;
199+ }
200+ squeezed_op_tz[j++] = x_vec_dims[i];
201+ }
202+
203+ in->set_mem_desc (in_md.reshape (squeezed_op_tz));
204+ in->Resize (phi::make_ddim (squeezed_op_tz));
205+ }
206+
207+ static void SetInMemDescWithLogicalLayoutFusesSupport (
208+ const framework::ExecutionContext& ctx,
209+ phi::DenseTensor* in,
210+ const dnnl::memory::desc& in_md) {
211+ if (ctx.HasAttr (" fused_squeeze2_axes" )) {
212+ SetInMemDescWithSqueeze2FuseSupport (ctx, in, in_md);
213+ } else {
214+ in->set_mem_desc (in_md);
215+ in->Resize (phi::make_ddim (in_md.dims ()));
216+ }
217+ }
218+
177219template <typename T>
178220constexpr bool IsInt8 () {
179221 return std::is_same<T, int8_t >::value || std::is_same<T, uint8_t >::value;
0 commit comments