@@ -428,7 +428,7 @@ bool Conv3dOpInferSymbolicShape(pir::Operation *op,
428
428
return Conv2dOpInferSymbolicShape (op, infer_context);
429
429
}
430
430
431
- bool convtransposefunction (pir::Operation *op,
431
+ bool ConvTransposeFunction (pir::Operation *op,
432
432
pir::InferSymbolicShapeContext *infer_context,
433
433
std::vector<symbol::DimExpr> output_size) {
434
434
auto x_shape =
@@ -517,14 +517,14 @@ bool convtransposefunction(pir::Operation *op,
517
517
" The Attr(output_padding) and Attr(stride) of Op(conv_transpose) "
518
518
" should be the same." ));
519
519
520
- const bool channel_last = (data_format != " NHWC" );
520
+ const bool normal_data = (data_format != " NHWC" );
521
521
const symbol::DimExpr C =
522
- (channel_last ? x_shape[1 ] : x_shape[x_shape.size () - 1 ]);
522
+ (normal_data ? x_shape[1 ] : x_shape[x_shape.size () - 1 ]);
523
523
524
524
infer_context->AddEqualCstr (filter_shape[0 ], C);
525
525
526
526
const std::vector<symbol::DimExpr> x_data_dims =
527
- channel_last
527
+ normal_data
528
528
? std::vector<symbol::DimExpr>(x_shape.begin () + 2 , x_shape.end ())
529
529
: std::vector<symbol::DimExpr>(x_shape.begin () + 1 ,
530
530
x_shape.end () - 1 );
@@ -542,10 +542,10 @@ bool convtransposefunction(pir::Operation *op,
542
542
ksize);
543
543
std::vector<symbol::DimExpr> output_shape ({x_shape[0 ]});
544
544
545
- if (channel_last ) {
545
+ if (normal_data ) {
546
546
output_shape.push_back (filter_shape[1 ] * groups);
547
547
}
548
- const int offset = (channel_last ? 2 : 1 ); // kNHWC
548
+ const int offset = (normal_data ? 2 : 1 ); // NHWC
549
549
for (int i = 0 ; i < static_cast <int >(strides.size ()); ++i) {
550
550
symbol::DimExpr filter_extent =
551
551
new_dilations[i] * (filter_shape[i + 2 ] - 1 ) + 1 ;
@@ -568,7 +568,7 @@ bool convtransposefunction(pir::Operation *op,
568
568
}
569
569
}
570
570
571
- if (!channel_last ) {
571
+ if (!normal_data ) {
572
572
output_shape.push_back (filter_shape[1 ] * groups);
573
573
}
574
574
infer_context->SetShapeOrDataForValue (
@@ -578,13 +578,6 @@ bool convtransposefunction(pir::Operation *op,
578
578
return true ;
579
579
}
580
580
581
- // bool ConvTransposeOpInferSymbolicShape(pir::Operation *op,
582
- // pir::InferSymbolicShapeContext
583
- // *infer_context) {
584
- // // pass
585
- // return true;
586
- // }
587
-
588
581
bool CrossOpInferSymbolicShape (pir::Operation *op,
589
582
pir::InferSymbolicShapeContext *infer_context) {
590
583
const auto &x_shape =
@@ -631,7 +624,7 @@ bool Conv3dTransposeOpInferSymbolicShape(
631
624
for (const auto &i : out_size) {
632
625
output_size.emplace_back (symbol::DimExpr{i});
633
626
}
634
- return convtransposefunction (op, infer_context, output_size);
627
+ return ConvTransposeFunction (op, infer_context, output_size);
635
628
}
636
629
637
630
bool Conv2dTransposeOpInferSymbolicShape (
@@ -644,15 +637,15 @@ bool Conv2dTransposeOpInferSymbolicShape(
644
637
for (const auto &i : out_size) {
645
638
output_size.emplace_back (symbol::DimExpr{i});
646
639
}
647
- return convtransposefunction (op, infer_context, output_size);
640
+ return ConvTransposeFunction (op, infer_context, output_size);
648
641
} else {
649
642
const auto &output_shape_or_data =
650
643
infer_context->GetShapeOrDataForValue (op->operand_source (2 ));
651
644
const std::vector<symbol::DimExpr> &output_size =
652
645
output_shape_or_data.data ().has_value ()
653
646
? output_shape_or_data.data ().value ()
654
647
: output_shape_or_data.shape ();
655
- return convtransposefunction (op, infer_context, output_size);
648
+ return ConvTransposeFunction (op, infer_context, output_size);
656
649
}
657
650
}
658
651
0 commit comments