Skip to content

Commit 1fecfe5

Browse files
committed
change
change change
1 parent 4e8ccc0 commit 1fecfe5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ bool MarginCrossEntropyOpInferSymbolicShape(
733733
size_t logits_rank = logits_dims.size();
734734
auto axis = logits_rank - 1;
735735

736-
for (int i = 0; i < logits_rank; i++) {
736+
for (size_t i = 0; i < logits_rank; i++) {
737737
if (i != axis && op->attribute<pir::BoolAttribute>("is_runtime").data()) {
738738
infer_context->AddBroadcastableCstr(logits_dims[i], labels_dims[i]);
739739
}
@@ -742,7 +742,7 @@ bool MarginCrossEntropyOpInferSymbolicShape(
742742
const auto one = symbol::DimExpr{1};
743743

744744
if (labels_dims.size() > 1) {
745-
infer_context->AddEqualCstr(logits_dims[axis], labels_dims[1]);
745+
infer_context->AddEqualCstr(logits_dims[axis], one);
746746
}
747747

748748
infer_context->SetShapeOrDataForValue(
@@ -759,7 +759,7 @@ bool MarginCrossEntropyOpInferSymbolicShape(
759759

760760
return true;
761761
}
762-
762+
763763
// bool PullBoxSparseOpInferSymbolicShape(pir::Operation *op,
764764
// pir::InferSymbolicShapeContext
765765
// *infer_context) {
@@ -780,7 +780,7 @@ bool MarginCrossEntropyOpInferSymbolicShape(
780780
// // pass
781781
// return true;
782782
// }
783-
783+
784784
bool SearchsortedOpInferSymbolicShape(
785785
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
786786
// The shape of output is the same as input `values` (op->operand_source(1))

0 commit comments

Comments
 (0)