Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,33 @@ bool AllcloseOpInferSymbolicShape(
return true;
}

bool BoxClipOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
const auto &im_info_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape();

// Check rank and dimensions of input tensors
const auto &three = symbol::DimExpr{3};
const auto &four = symbol::DimExpr{4};
infer_context->AddEqualCstr(input_shape[input_shape.size() - 1], four);
PADDLE_ENFORCE_EQ(im_info_shape.size(),
2,
common::errors::InvalidArgument(
"The rank of Input(im_info) in BoxClipOp must be 2. "
"But received rank = %d",
im_info_shape.size()));
infer_context->AddEqualCstr(im_info_shape[1], three);

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(input_shape)});

return true;
}

bool Atan2OpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto x_shape =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Allclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Atan2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BoxClip)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@
func: BoxClipInferMeta
kernel:
func: box_clip
interfaces: paddle::dialect::InferSymbolicShapeInterface

- op : box_coder
args : (Tensor prior_box, Tensor prior_box_var, Tensor target_box, str code_type = "encode_center_size", bool box_normalized = true, int axis = 0, float[] variance = {})
Expand Down