Skip to content

Commit 2a26f80

Browse files
authored
【Infer Symbolic Shape No.122】【BUAA】Add batch_fc (#67348)
* batch_fc * add constrain
1 parent 4012436 commit 2a26f80

File tree

3 files changed

+37
-7
lines changed

3 files changed

+37
-7
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,41 @@ bool AucOpInferSymbolicShape(pir::Operation *op,
228228
return true;
229229
}
230230

231-
// bool BatchFcOpInferSymbolicShape(pir::Operation *op,
232-
// pir::InferSymbolicShapeContext
233-
// *infer_context) {
234-
// // pass
235-
// return true;
236-
// }
231+
bool BatchFcOpInferSymbolicShape(
232+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
233+
const auto &input_shape_or_data =
234+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
235+
const auto &w_shape_or_data =
236+
infer_context->GetShapeOrDataForValue(op->operand_source(1));
237+
const auto &bias_shape_or_data =
238+
infer_context->GetShapeOrDataForValue(op->operand_source(2));
239+
240+
const std::vector<symbol::DimExpr> &input_dims = input_shape_or_data.shape();
241+
const std::vector<symbol::DimExpr> &w_dims = w_shape_or_data.shape();
242+
const std::vector<symbol::DimExpr> &bias_dims = bias_shape_or_data.shape();
243+
244+
PADDLE_ENFORCE_EQ(
245+
input_dims.size(),
246+
3,
247+
common::errors::InvalidArgument("Input of BatchFcOp should have 3D."));
248+
PADDLE_ENFORCE_EQ(
249+
w_dims.size(),
250+
3,
251+
common::errors::InvalidArgument("W of BatchFcOp should have 3D."));
252+
infer_context->AddEqualCstr(input_dims[0], w_dims[0]);
253+
infer_context->AddEqualCstr(input_dims[2], w_dims[1]);
254+
infer_context->AddEqualCstr(bias_dims[0], input_dims[0]);
255+
infer_context->AddEqualCstr(bias_dims[1], w_dims[2]);
256+
257+
std::vector<symbol::DimExpr> out_dims = {
258+
input_dims[0], input_dims[1], w_dims[2]};
259+
260+
infer_context->SetShapeOrDataForValue(
261+
op->result(0),
262+
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)});
263+
264+
return true;
265+
}
237266

238267
bool BatchNormOpInferSymbolicShape(
239268
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN)
2525
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc)
2626
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos)
2727
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors)
28-
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc)
28+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc)
2929
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm)
3030
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_)
3131
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@
466466
func : batch_fc
467467
data_type: input
468468
backward: batch_fc_grad
469+
interfaces : paddle::dialect::InferSymbolicShapeInterface
469470

470471
- op : bce_loss
471472
args : (Tensor input, Tensor label)

0 commit comments

Comments
 (0)