Skip to content

Commit b455a17

Browse files
committed
fix comment:
fix comment fix comment
1 parent 954ec66 commit b455a17

File tree

1 file changed

+10
-17
lines changed

1 file changed

+10
-17
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ bool Conv3dOpInferSymbolicShape(pir::Operation *op,
428428
return Conv2dOpInferSymbolicShape(op, infer_context);
429429
}
430430

431-
bool convtransposefunction(pir::Operation *op,
431+
bool ConvTransposeFunction(pir::Operation *op,
432432
pir::InferSymbolicShapeContext *infer_context,
433433
std::vector<symbol::DimExpr> output_size) {
434434
auto x_shape =
@@ -517,14 +517,14 @@ bool convtransposefunction(pir::Operation *op,
517517
"The Attr(output_padding) and Attr(stride) of Op(conv_transpose) "
518518
"should be the same."));
519519

520-
const bool channel_last = (data_format != "NHWC");
520+
const bool normal_data = (data_format != "NHWC");
521521
const symbol::DimExpr C =
522-
(channel_last ? x_shape[1] : x_shape[x_shape.size() - 1]);
522+
(normal_data ? x_shape[1] : x_shape[x_shape.size() - 1]);
523523

524524
infer_context->AddEqualCstr(filter_shape[0], C);
525525

526526
const std::vector<symbol::DimExpr> x_data_dims =
527-
channel_last
527+
normal_data
528528
? std::vector<symbol::DimExpr>(x_shape.begin() + 2, x_shape.end())
529529
: std::vector<symbol::DimExpr>(x_shape.begin() + 1,
530530
x_shape.end() - 1);
@@ -542,10 +542,10 @@ bool convtransposefunction(pir::Operation *op,
542542
ksize);
543543
std::vector<symbol::DimExpr> output_shape({x_shape[0]});
544544

545-
if (channel_last) {
545+
if (normal_data) {
546546
output_shape.push_back(filter_shape[1] * groups);
547547
}
548-
const int offset = (channel_last ? 2 : 1); // kNHWC
548+
const int offset = (normal_data ? 2 : 1); // NHWC
549549
for (int i = 0; i < static_cast<int>(strides.size()); ++i) {
550550
symbol::DimExpr filter_extent =
551551
new_dilations[i] * (filter_shape[i + 2] - 1) + 1;
@@ -568,7 +568,7 @@ bool convtransposefunction(pir::Operation *op,
568568
}
569569
}
570570

571-
if (!channel_last) {
571+
if (!normal_data) {
572572
output_shape.push_back(filter_shape[1] * groups);
573573
}
574574
infer_context->SetShapeOrDataForValue(
@@ -578,13 +578,6 @@ bool convtransposefunction(pir::Operation *op,
578578
return true;
579579
}
580580

581-
// bool ConvTransposeOpInferSymbolicShape(pir::Operation *op,
582-
// pir::InferSymbolicShapeContext
583-
// *infer_context) {
584-
// // pass
585-
// return true;
586-
// }
587-
588581
bool CrossOpInferSymbolicShape(pir::Operation *op,
589582
pir::InferSymbolicShapeContext *infer_context) {
590583
const auto &x_shape =
@@ -631,7 +624,7 @@ bool Conv3dTransposeOpInferSymbolicShape(
631624
for (const auto &i : out_size) {
632625
output_size.emplace_back(symbol::DimExpr{i});
633626
}
634-
return convtransposefunction(op, infer_context, output_size);
627+
return ConvTransposeFunction(op, infer_context, output_size);
635628
}
636629

637630
bool Conv2dTransposeOpInferSymbolicShape(
@@ -644,15 +637,15 @@ bool Conv2dTransposeOpInferSymbolicShape(
644637
for (const auto &i : out_size) {
645638
output_size.emplace_back(symbol::DimExpr{i});
646639
}
647-
return convtransposefunction(op, infer_context, output_size);
640+
return ConvTransposeFunction(op, infer_context, output_size);
648641
} else {
649642
const auto &output_shape_or_data =
650643
infer_context->GetShapeOrDataForValue(op->operand_source(2));
651644
const std::vector<symbol::DimExpr> &output_size =
652645
output_shape_or_data.data().has_value()
653646
? output_shape_or_data.data().value()
654647
: output_shape_or_data.shape();
655-
return convtransposefunction(op, infer_context, output_size);
648+
return ConvTransposeFunction(op, infer_context, output_size);
656649
}
657650
}
658651

0 commit comments

Comments
 (0)