@@ -228,12 +228,41 @@ bool AucOpInferSymbolicShape(pir::Operation *op,
228
228
return true ;
229
229
}
230
230
231
- // bool BatchFcOpInferSymbolicShape(pir::Operation *op,
232
- // pir::InferSymbolicShapeContext
233
- // *infer_context) {
234
- // // pass
235
- // return true;
236
- // }
231
+ bool BatchFcOpInferSymbolicShape (
232
+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
233
+ const auto &input_shape_or_data =
234
+ infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
235
+ const auto &w_shape_or_data =
236
+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
237
+ const auto &bias_shape_or_data =
238
+ infer_context->GetShapeOrDataForValue (op->operand_source (2 ));
239
+
240
+ const std::vector<symbol::DimExpr> &input_dims = input_shape_or_data.shape ();
241
+ const std::vector<symbol::DimExpr> &w_dims = w_shape_or_data.shape ();
242
+ const std::vector<symbol::DimExpr> &bias_dims = bias_shape_or_data.shape ();
243
+
244
+ PADDLE_ENFORCE_EQ (
245
+ input_dims.size (),
246
+ 3 ,
247
+ common::errors::InvalidArgument (" Input of BatchFcOp should have 3D." ));
248
+ PADDLE_ENFORCE_EQ (
249
+ w_dims.size (),
250
+ 3 ,
251
+ common::errors::InvalidArgument (" W of BatchFcOp should have 3D." ));
252
+ infer_context->AddEqualCstr (input_dims[0 ], w_dims[0 ]);
253
+ infer_context->AddEqualCstr (input_dims[2 ], w_dims[1 ]);
254
+ infer_context->AddEqualCstr (bias_dims[0 ], input_dims[0 ]);
255
+ infer_context->AddEqualCstr (bias_dims[1 ], w_dims[2 ]);
256
+
257
+ std::vector<symbol::DimExpr> out_dims = {
258
+ input_dims[0 ], input_dims[1 ], w_dims[2 ]};
259
+
260
+ infer_context->SetShapeOrDataForValue (
261
+ op->result (0 ),
262
+ symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs (out_dims)});
263
+
264
+ return true ;
265
+ }
237
266
238
267
bool BatchNormOpInferSymbolicShape (
239
268
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
0 commit comments