File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -567,17 +567,34 @@ bool MarginCrossEntropyOpInferSymbolicShape(
567567 pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
568568 const auto &logits_shape_or_data =
569569 infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
570+ const auto &labels_shape_or_data =
571+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
570572
571573 std::vector<symbol::DimExpr> logits_dims = logits_shape_or_data.shape ();
574+ std::vector<symbol::DimExpr> labels_dims = labels_shape_or_data.shape ();
572575
573576 size_t logits_rank = logits_dims.size ();
574577 auto axis = logits_rank - 1 ;
575578
579+ for (int i = 0 ; i < logits_rank; i++) {
580+ if (i != axis && op->attribute <pir::BoolAttribute>(" is_runtime" ).data ()) {
581+ infer_context->AddBroadcastableCstr (logits_dims[i], labels_dims[i]);
582+ }
583+ }
584+
585+ const auto one = symbol::DimExpr{1 };
586+
587+ if (labels_dims.size () > 1 ) {
588+ infer_context->AddEqualCstr (logits_dims[axis], labels_dims[1 ]);
589+ }
590+
576591 infer_context->SetShapeOrDataForValue (
577592 op->result (0 ),
578593 symbol::ShapeOrDataDimExprs{
579594 symbol::TensorShapeOrDataDimExprs (logits_dims)});
595+
580596 logits_dims[axis] = symbol::DimExpr (1 );
597+
581598 infer_context->SetShapeOrDataForValue (
582599 op->result (1 ),
583600 symbol::ShapeOrDataDimExprs{
You can’t perform that action at this time.
0 commit comments