Skip to content

Commit d678aba

Browse files
【Infer Symbolic Shape BUAA No.42-44】Add ops (#66930)
1 parent efdd967 commit d678aba

File tree

5 files changed

+78
-0
lines changed

5 files changed

+78
-0
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,51 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op,
782782
return true;
783783
}
784784

785+
bool FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(
786+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
787+
const symbol::ShapeOrDataDimExprs &x_shape =
788+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
789+
790+
// Validate the bit_length attribute
791+
int bit_length = op->attribute<pir::Int32Attribute>("bit_length").data();
792+
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
793+
true,
794+
phi::errors::InvalidArgument(
795+
"'bit_length' should be between 1 and 16, but "
796+
"the received is %d",
797+
bit_length));
798+
799+
// Set the shape for the output tensor 'out', same as input tensor 'x'
800+
infer_context->SetShapeOrDataForValue(op->result(0), x_shape);
801+
802+
// Create a scalar shape for the other output tensors
803+
symbol::TensorShapeOrDataDimExprs scalar_shape(
804+
std::vector<symbol::DimExpr>{symbol::DimExpr(1)});
805+
806+
// Set the shape for all scalar output tensors: 'out_scale', 'out_state',
807+
// 'out_accum'
808+
for (size_t i = 1; i < op->num_results(); ++i) {
809+
infer_context->SetShapeOrDataForValue(op->result(i), scalar_shape);
810+
}
811+
812+
return true;
813+
}
814+
815+
bool FakeQuantizeMovingAverageAbsMax_OpInferSymbolicShape(
816+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
817+
return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
818+
}
819+
820+
bool FakeQuantizeDequantizeMovingAverageAbsMaxOpInferSymbolicShape(
821+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
822+
return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
823+
}
824+
825+
bool FakeQuantizeDequantizeMovingAverageAbsMax_OpInferSymbolicShape(
826+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
827+
return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
828+
}
829+
785830
bool FullWithTensorOpInferSymbolicShape(
786831
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
787832
pir::Value operand_source = op->operand_source(1);

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
3535
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat)
3636
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax)
3737
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax_)
38+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeMovingAverageAbsMax)
39+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeMovingAverageAbsMax_)
40+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeMovingAverageAbsMax)
41+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeMovingAverageAbsMax_)
3842
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(CoalesceTensor)
3943
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(CoalesceTensor_)
4044
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,31 @@ bool FlattenOpInferSymbolicShape(
943943
return true;
944944
}
945945

946+
bool FakeQuantizeDequantizeAbsMaxOpInferSymbolicShape(
947+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
948+
const symbol::ShapeOrDataDimExprs &x_shape =
949+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
950+
951+
// Validate the bit_length attribute
952+
int bit_length = op->attribute<pir::Int32Attribute>("bit_length").data();
953+
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
954+
true,
955+
phi::errors::InvalidArgument(
956+
"'bit_length' should be between 1 and 16, but "
957+
"the received is %d",
958+
bit_length));
959+
960+
// Set the shape for the output tensor 'out', same as input tensor 'x'
961+
infer_context->SetShapeOrDataForValue(op->result(0), x_shape);
962+
963+
// Set the shape for the output tensor 'out_scale' as a scalar {1}
964+
symbol::TensorShapeOrDataDimExprs scalar_shape(
965+
std::vector<symbol::DimExpr>{symbol::DimExpr(1)});
966+
infer_context->SetShapeOrDataForValue(op->result(1), scalar_shape);
967+
968+
return true;
969+
}
970+
946971
bool Flatten_OpInferSymbolicShape(
947972
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
948973
return FlattenOpInferSymbolicShape(op, infer_context);

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal)
6262
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal_)
6363
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten)
6464
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten_)
65+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeAbsMax)
6566
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Inverse)
6667
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GumbelSoftmax)
6768
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IdentityLoss)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,6 +1715,7 @@
17151715
func : fake_quantize_dequantize_abs_max
17161716
data_type : x
17171717
backward : fake_quantize_dequantize_abs_max_grad
1718+
interfaces : paddle::dialect::InferSymbolicShapeInterface
17181719

17191720
- op : fake_quantize_dequantize_moving_average_abs_max
17201721
args : (Tensor x, Tensor in_scale, Tensor in_accum, Tensor in_state, float moving_rate = 0.9, int bit_length = 8, bool is_test = false, int round_type = 1)
@@ -1727,6 +1728,7 @@
17271728
optional : in_accum, in_state, out_state, out_accum
17281729
backward : fake_quantize_dequantize_moving_average_abs_max_grad
17291730
inplace: (in_scale -> out_scale)
1731+
interfaces : paddle::dialect::InferSymbolicShapeInterface
17301732

17311733
- op : fake_quantize_moving_average_abs_max
17321734
args : (Tensor x, Tensor in_scale, Tensor in_accum, Tensor in_state, float moving_rate = 0.9, int bit_length = 8, bool is_test = false, int round_type = 1)
@@ -1738,6 +1740,7 @@
17381740
data_type : x
17391741
optional : in_accum, in_state, out_state, out_accum
17401742
inplace: (in_scale -> out_scale)
1743+
interfaces : paddle::dialect::InferSymbolicShapeInterface
17411744

17421745
- op : fake_quantize_range_abs_max
17431746
args : (Tensor x, Tensor in_scale, Tensor iter, int window_size = 10000, int bit_length = 8, bool is_test = false, int round_type = 1)

0 commit comments

Comments
 (0)