Skip to content

Commit 7232797

Browse files
authored
【Infer Symbolic Shape BUAA No.10 No.73 No.78】Add mode , margin_cross_entropy , box_clip. op (#66730)
* mode * boxclip * test_op * testop * margin_cross_entropy_op * resolve conflict * resolve conflict * resolve conflict * testop * testop * testop * tensor * fix fix conflict fix conflict fix conflict * fix * test * fix problem * fix equal * Update unused viriable * static check * rerun * add cstr * fix small problem * equal constrain * change change change * codestyle * fix error * comment fix * fix comment * fix
1 parent ed69456 commit 7232797

File tree

5 files changed

+109
-0
lines changed

5 files changed

+109
-0
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,33 @@ bool AllcloseOpInferSymbolicShape(
9191
return true;
9292
}
9393

94+
bool BoxClipOpInferSymbolicShape(
95+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
96+
const auto &input_shape =
97+
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
98+
const auto &im_info_shape =
99+
infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape();
100+
101+
// Check rank and dimensions of input tensors
102+
const auto &three = symbol::DimExpr{3};
103+
const auto &four = symbol::DimExpr{4};
104+
infer_context->AddEqualCstr(input_shape[input_shape.size() - 1], four);
105+
PADDLE_ENFORCE_EQ(im_info_shape.size(),
106+
2,
107+
common::errors::InvalidArgument(
108+
"The rank of Input(im_info) in BoxClipOp must be 2. "
109+
"But received rank = %d",
110+
im_info_shape.size()));
111+
infer_context->AddEqualCstr(im_info_shape[1], three);
112+
113+
infer_context->SetShapeOrDataForValue(
114+
op->result(0),
115+
symbol::ShapeOrDataDimExprs{
116+
symbol::TensorShapeOrDataDimExprs(input_shape)});
117+
118+
return true;
119+
}
120+
94121
bool Atan2OpInferSymbolicShape(pir::Operation *op,
95122
pir::InferSymbolicShapeContext *infer_context) {
96123
const auto x_shape =
@@ -781,6 +808,46 @@ bool MatmulOpInferSymbolicShape(pir::Operation *op,
781808
return true;
782809
}
783810

811+
bool MarginCrossEntropyOpInferSymbolicShape(
812+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
813+
const auto &logits_shape_or_data =
814+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
815+
const auto &labels_shape_or_data =
816+
infer_context->GetShapeOrDataForValue(op->operand_source(1));
817+
818+
std::vector<symbol::DimExpr> logits_dims = logits_shape_or_data.shape();
819+
std::vector<symbol::DimExpr> labels_dims = labels_shape_or_data.shape();
820+
821+
size_t logits_rank = logits_dims.size();
822+
auto axis = logits_rank - 1;
823+
824+
for (size_t i = 0; i < logits_rank; i++) {
825+
if (i != axis) {
826+
infer_context->AddEqualCstr(logits_dims[i], labels_dims[i]);
827+
}
828+
}
829+
830+
const auto &one = symbol::DimExpr{1};
831+
832+
if (labels_dims.size() > 1) {
833+
infer_context->AddEqualCstr(labels_dims[axis - 1], one);
834+
}
835+
836+
infer_context->SetShapeOrDataForValue(
837+
op->result(0),
838+
symbol::ShapeOrDataDimExprs{
839+
symbol::TensorShapeOrDataDimExprs(logits_dims)});
840+
841+
logits_dims[axis] = symbol::DimExpr(1);
842+
843+
infer_context->SetShapeOrDataForValue(
844+
op->result(1),
845+
symbol::ShapeOrDataDimExprs{
846+
symbol::TensorShapeOrDataDimExprs(logits_dims)});
847+
848+
return true;
849+
}
850+
784851
bool MvOpInferSymbolicShape(pir::Operation *op,
785852
pir::InferSymbolicShapeContext *infer_context) {
786853
const auto &x_shape_or_data =

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Allclose)
2222
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Atan2)
2323
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss)
2424
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_)
25+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BoxClip)
2526
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial)
2627
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial_)
2728
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount)
@@ -53,6 +54,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
5354
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
5455
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect)
5556
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul)
57+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MarginCrossEntropy)
5658
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mv)
5759
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullBoxSparse)
5860
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullGpuPsSparse)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,42 @@ bool MaxOpInferSymbolicShape(pir::Operation *op,
11461146

11471147
return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all);
11481148
}
1149+
bool ModeOpInferSymbolicShape(pir::Operation *op,
1150+
pir::InferSymbolicShapeContext *infer_context) {
1151+
const auto &x_shape_or_data =
1152+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
1153+
const auto &x_shape = x_shape_or_data.shape();
1154+
1155+
int axis = op->attribute<pir::Int32Attribute>("axis").data();
1156+
bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data();
1157+
1158+
int dim_size = x_shape.size();
1159+
1160+
if (axis < 0) {
1161+
axis += dim_size;
1162+
}
1163+
1164+
std::vector<symbol::DimExpr> out_dims;
1165+
for (int i = 0; i < axis; i++) {
1166+
out_dims.emplace_back(x_shape[i]);
1167+
}
1168+
if (keepdim && dim_size > 0) {
1169+
out_dims.emplace_back(symbol::DimExpr(1));
1170+
}
1171+
for (int i = axis + 1; i < dim_size; i++) {
1172+
out_dims.emplace_back(x_shape[i]);
1173+
}
1174+
1175+
symbol::TensorShapeOrDataDimExprs out_shape(out_dims);
1176+
1177+
infer_context->SetShapeOrDataForValue(op->result(0),
1178+
symbol::ShapeOrDataDimExprs{out_shape});
1179+
1180+
infer_context->SetShapeOrDataForValue(op->result(1),
1181+
symbol::ShapeOrDataDimExprs{out_shape});
1182+
1183+
return true;
1184+
}
11491185

11501186
bool MaxoutOpInferSymbolicShape(pir::Operation *op,
11511187
pir::InferSymbolicShapeContext *infer_context) {

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm_)
7474
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LpPool2d)
7575
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp)
7676
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp)
77+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mode)
7778
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
7879
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maxout)
7980
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@
652652
func: BoxClipInferMeta
653653
kernel:
654654
func: box_clip
655+
interfaces: paddle::dialect::InferSymbolicShapeInterface
655656

656657
- op : box_coder
657658
args : (Tensor prior_box, Tensor prior_box_var, Tensor target_box, str code_type = "encode_center_size", bool box_normalized = true, int axis = 0, float[] variance = {})
@@ -3019,6 +3020,7 @@
30193020
func : margin_cross_entropy
30203021
data_type : logits
30213022
backward : margin_cross_entropy_grad
3023+
interfaces : paddle::dialect::InferSymbolicShapeInterface
30223024

30233025
- op : masked_multihead_attention_
30243026
args : (Tensor x, Tensor cache_kv, Tensor bias, Tensor src_mask, Tensor cum_offsets, Tensor sequence_lengths, Tensor rotary_tensor, Tensor beam_cache_offset, Tensor qkv_out_scale, Tensor out_shift, Tensor out_smooth, int seq_len, int rotary_emb_dims, bool use_neox_rotary_style=false, str compute_dtype = "default", float out_scale=-1, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0)
@@ -3239,6 +3241,7 @@
32393241
kernel :
32403242
func : mode
32413243
backward : mode_grad
3244+
interfaces : paddle::dialect::InferSymbolicShapeInterface
32423245

32433246
- op : momentum_
32443247
args : (Tensor param, Tensor grad, Tensor velocity, Tensor learning_rate, Tensor master_param, float mu, bool use_nesterov = false, str regularization_method = "", float regularization_coeff = 0.0f, bool multi_precision = false, float rescale_grad = 1.0f)

0 commit comments

Comments
 (0)