Skip to content

Commit d6b9b1d

Browse files
committed
equal constrain
1 parent 1c4e9aa commit d6b9b1d

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,17 +567,34 @@ bool MarginCrossEntropyOpInferSymbolicShape(
567567
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
568568
const auto &logits_shape_or_data =
569569
infer_context->GetShapeOrDataForValue(op->operand_source(0));
570+
const auto &labels_shape_or_data =
571+
infer_context->GetShapeOrDataForValue(op->operand_source(1));
570572

571573
std::vector<symbol::DimExpr> logits_dims = logits_shape_or_data.shape();
574+
std::vector<symbol::DimExpr> labels_dims = labels_shape_or_data.shape();
572575

573576
size_t logits_rank = logits_dims.size();
574577
auto axis = logits_rank - 1;
575578

579+
for (int i = 0; i < logits_rank; i++) {
580+
if (i != axis && op->attribute<pir::BoolAttribute>("is_runtime").data()) {
581+
infer_context->AddBroadcastableCstr(logits_dims[i], labels_dims[i]);
582+
}
583+
}
584+
585+
const auto one = symbol::DimExpr{1};
586+
587+
if (labels_dims.size() > 1) {
588+
infer_context->AddEqualCstr(logits_dims[axis], labels_dims[1]);
589+
}
590+
576591
infer_context->SetShapeOrDataForValue(
577592
op->result(0),
578593
symbol::ShapeOrDataDimExprs{
579594
symbol::TensorShapeOrDataDimExprs(logits_dims)});
595+
580596
logits_dims[axis] = symbol::DimExpr(1);
597+
581598
infer_context->SetShapeOrDataForValue(
582599
op->result(1),
583600
symbol::ShapeOrDataDimExprs{

0 commit comments

Comments
 (0)