Skip to content

Commit 5c1cfc7

Browse files
authored
[Infer Symbolic Shape No.27,71][BUAA]edit_distance,kldiv_loss (#67327)
* Finished kldiv loss op * Finished kldiv loss op * Fixed name issues * Resolved comments in #67117
1 parent 7464d9e commit 5c1cfc7

File tree

4 files changed

+49
-3
lines changed

4 files changed

+49
-3
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,51 @@ bool IndexSampleOpInferSymbolicShape(
643643
return true;
644644
}
645645

646+
bool KldivLossOpInferSymbolicShape(
647+
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
648+
const auto &x_shape_or_data =
649+
infer_context->GetShapeOrDataForValue(op->operand_source(0));
650+
const auto &label_shape_or_data =
651+
infer_context->GetShapeOrDataForValue(op->operand_source(1));
652+
const auto &x_shape = x_shape_or_data.shape();
653+
const auto &label_shape = label_shape_or_data.shape();
654+
655+
PADDLE_ENFORCE_EQ(x_shape.size(),
656+
label_shape.size(),
657+
common::errors::InvalidArgument(
658+
"Input(X) rank and Input(Target) rank should be same, "
659+
"but received X rank(%d) != Target rank(%d)",
660+
x_shape.size(),
661+
label_shape.size()));
662+
663+
for (size_t i = 0; i < x_shape.size(); ++i) {
664+
infer_context->AddEqualCstr(x_shape[i], label_shape[i]);
665+
}
666+
667+
std::string reduction =
668+
op->attribute<pir::StrAttribute>("reduction").AsString();
669+
bool reduction_valid = (reduction == "mean" || reduction == "sum" ||
670+
reduction == "batchmean" || reduction == "none");
671+
PADDLE_ENFORCE_EQ(
672+
reduction_valid,
673+
true,
674+
common::errors::InvalidArgument(
675+
"Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'."));
676+
677+
std::vector<symbol::DimExpr> out_shape;
678+
if (reduction == "none") {
679+
out_shape = x_shape;
680+
} else {
681+
out_shape = std::vector<symbol::DimExpr>{};
682+
}
683+
infer_context->SetShapeOrDataForValue(
684+
op->result(0),
685+
symbol::ShapeOrDataDimExprs{
686+
symbol::TensorShapeOrDataDimExprs(out_shape)});
687+
688+
return true;
689+
}
690+
646691
bool KronOpInferSymbolicShape(pir::Operation *op,
647692
pir::InferSymbolicShapeContext *infer_context) {
648693
const auto &x_shape_or_data =

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isclose)
4949
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AccuracyCheck)
5050
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSample)
5151
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSelectStrided)
52+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(KldivLoss)
5253
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
5354
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lstsq)
5455
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -859,11 +859,10 @@ bool EditDistanceOpInferSymbolicShape(
859859
infer_context->AddEqualCstr(refs_dims[1], one);
860860
}
861861

862-
symbol::ShapeOrDataDimExprs refs_shape_or_data_exprs(
862+
symbol::ShapeOrDataDimExprs out_shape_or_data_exprs(
863863
symbol::TensorShapeOrDataDimExprs(
864864
std::vector<symbol::DimExpr>{refs_dims}));
865-
infer_context->SetShapeOrDataForValue(op->result(0),
866-
refs_shape_or_data_exprs);
865+
infer_context->SetShapeOrDataForValue(op->result(0), out_shape_or_data_exprs);
867866

868867
symbol::ShapeOrDataDimExprs single_dim_expr(symbol::TensorShapeOrDataDimExprs(
869868
std::vector<symbol::DimExpr>{symbol::DimExpr(1)}));

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2624,6 +2624,7 @@
26242624
func : kldiv_loss
26252625
data_type : x
26262626
backward : kldiv_loss_grad
2627+
interfaces : paddle::dialect::InferSymbolicShapeInterface
26272628

26282629
- op : kron
26292630
args : (Tensor x, Tensor y)

0 commit comments

Comments
 (0)