@@ -733,7 +733,7 @@ bool MarginCrossEntropyOpInferSymbolicShape(
733733 size_t logits_rank = logits_dims.size ();
734734 auto axis = logits_rank - 1 ;
735735
736- for (int i = 0 ; i < logits_rank; i++) {
736+ for (size_t i = 0 ; i < logits_rank; i++) {
737737 if (i != axis && op->attribute <pir::BoolAttribute>(" is_runtime" ).data ()) {
738738 infer_context->AddBroadcastableCstr (logits_dims[i], labels_dims[i]);
739739 }
@@ -742,7 +742,7 @@ bool MarginCrossEntropyOpInferSymbolicShape(
742742 const auto one = symbol::DimExpr{1 };
743743
744744 if (labels_dims.size () > 1 ) {
745- infer_context->AddEqualCstr (logits_dims[axis], labels_dims[ 1 ] );
745+ infer_context->AddEqualCstr (logits_dims[axis], one );
746746 }
747747
748748 infer_context->SetShapeOrDataForValue (
@@ -759,7 +759,7 @@ bool MarginCrossEntropyOpInferSymbolicShape(
759759
760760 return true ;
761761}
762-
762+
763763// bool PullBoxSparseOpInferSymbolicShape(pir::Operation *op,
764764// pir::InferSymbolicShapeContext
765765// *infer_context) {
@@ -780,7 +780,7 @@ bool MarginCrossEntropyOpInferSymbolicShape(
780780// // pass
781781// return true;
782782// }
783-
783+
784784bool SearchsortedOpInferSymbolicShape (
785785 pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
786786 // The shape of output is the same as input `values` (op->operand_source(1))
0 commit comments