@@ -89,6 +89,13 @@ class Pool2dOpPattern
89
89
op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
90
90
return false ;
91
91
}
92
+ paddle::dialect::FullIntArrayOp full_int_array_op =
93
+ pir::GetDefiningOpForInput (op, 1 )
94
+ ->dyn_cast <paddle::dialect::FullIntArrayOp>();
95
+ if (!full_int_array_op) {
96
+ VLOG (3 ) << " Cannot find FullIntArrayOp" ;
97
+ return false ;
98
+ }
92
99
auto padding_attr = op->attribute <pir::ArrayAttribute>(" paddings" );
93
100
std::vector<int32_t > paddings;
94
101
for (const auto &attr : padding_attr.AsVector ()) {
@@ -122,28 +129,19 @@ class Pool2dOpPattern
122
129
if (!op->attribute <pir::BoolAttribute>(" global_pooling" ).data ()) {
123
130
if (op->HasAttribute (" exclusive" )) {
124
131
if (op->attribute <pir::BoolAttribute>(" exclusive" ).data ()) {
125
- paddle::dialect::FullIntArrayOp full_int_array_op =
126
- pir::GetDefiningOpForInput (op, 1 )
127
- ->dyn_cast <paddle::dialect::FullIntArrayOp>();
128
- if (!full_int_array_op) {
129
- VLOG (3 ) << " Cannot find FullIntArrayOp" ;
130
- return false ;
131
- } else {
132
- auto attr_value =
133
- full_int_array_op->attribute <pir::ArrayAttribute>(
134
- " value" );
135
- std::vector<int64_t > kernel_size;
136
- for (const auto &attr : attr_value.AsVector ()) {
137
- kernel_size.push_back (
138
- attr.dyn_cast <pir::Int64Attribute>().data ());
139
- }
140
- for (size_t i = 0 ; i < kernel_size.size (); ++i) {
141
- if (kernel_size[i] <= paddings[i]) {
142
- VLOG (3 ) << " the padding size should be less than the "
143
- " filter size "
144
- " for exclusive-counting pooling." ;
145
- return false ;
146
- }
132
+ auto attr_value =
133
+ full_int_array_op->attribute <pir::ArrayAttribute>(" value" );
134
+ std::vector<int64_t > kernel_size;
135
+ for (const auto &attr : attr_value.AsVector ()) {
136
+ kernel_size.push_back (
137
+ attr.dyn_cast <pir::Int64Attribute>().data ());
138
+ }
139
+ for (size_t i = 0 ; i < kernel_size.size (); ++i) {
140
+ if (kernel_size[i] <= paddings[i]) {
141
+ VLOG (3 ) << " the padding size should be less than the "
142
+ " filter size "
143
+ " for exclusive-counting pooling." ;
144
+ return false ;
147
145
}
148
146
}
149
147
}
@@ -796,42 +794,42 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
796
794
return false ;
797
795
}
798
796
799
- paddle::dialect::FullOp full_op =
800
- pir::GetDefiningOpForInput (op, 2 )->dyn_cast <paddle::dialect::FullOp>();
801
- if (!full_op) {
802
- VLOG (3 ) << " Can not find full op" ;
797
+ pir::Value axis_tensor = op.operand_source (2 );
798
+ if (!axis_tensor) {
799
+ VLOG (3 ) << " pd_op.split can not find axis input" ;
803
800
return false ;
804
- } else {
801
+ }
802
+ auto out_vector_type = op.result (0 ).type ().dyn_cast <pir::VectorType>();
803
+ if (pir::GetDefiningOpForInput (op, 2 )->isa <paddle::dialect::FullOp>()) {
804
+ paddle::dialect::FullOp full_op =
805
+ pir::GetDefiningOpForInput (op, 2 )
806
+ ->dyn_cast <paddle::dialect::FullOp>();
805
807
auto axis = full_op->attribute <paddle::dialect::ScalarAttribute>(" value" )
806
808
.data ()
807
809
.to <int >();
808
810
auto x_shape = op.operand_source (0 )
809
811
.type ()
810
812
.dyn_cast <paddle::dialect::DenseTensorType>()
811
813
.dims ();
812
- auto out_vector_type = op.result (0 ).type ().dyn_cast <pir::VectorType>();
813
814
814
- paddle::dialect::FullIntArrayOp full_sections_op =
815
- pir::GetDefiningOpForInput (op, 1 )
816
- ->dyn_cast <paddle::dialect::FullIntArrayOp>();
817
- if (!full_sections_op) {
818
- VLOG (3 ) << " Can not find FullIntArrayOp" ;
815
+ axis += (axis < 0 ) ? x_shape.size () : 0 ;
816
+
817
+ if (x_shape[axis] == -1 ) {
818
+ VLOG (3 ) << " The (" << axis << " ) dim of input should not be -1" ;
819
819
return false ;
820
820
}
821
+ }
821
822
823
+ if (pir::GetDefiningOpForInput (op, 1 )
824
+ ->isa <paddle::dialect::FullIntArrayOp>()) {
825
+ paddle::dialect::FullIntArrayOp full_sections_op =
826
+ pir::GetDefiningOpForInput (op, 1 )
827
+ ->dyn_cast <paddle::dialect::FullIntArrayOp>();
822
828
auto sections = full_sections_op->attribute <pir::ArrayAttribute>(" value" );
823
-
824
829
std::vector<int64_t > output_lengths;
825
830
for (const auto &attr : sections.AsVector ()) {
826
831
output_lengths.push_back (attr.dyn_cast <pir::Int64Attribute>().data ());
827
832
}
828
- axis += (axis < 0 ) ? x_shape.size () : 0 ;
829
-
830
- if (x_shape[axis] == -1 ) {
831
- VLOG (3 ) << " The (" << axis << " ) dim of input should not be -1" ;
832
- return false ;
833
- }
834
-
835
833
if (output_lengths.size () != out_vector_type.size ()) {
836
834
VLOG (3 ) << " The output_length should be equal to the output size." ;
837
835
return false ;
@@ -853,33 +851,38 @@ class SplitWithNumOpPattern
853
851
op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
854
852
return false ;
855
853
}
856
- paddle::dialect::FullOp full_op =
857
- pir::GetDefiningOpForInput (op, 1 )-> dyn_cast <paddle::dialect::FullOp>( );
858
- if (!full_op ) {
859
- VLOG (3 ) << " Can not find full op " ;
854
+
855
+ pir::Value axis_tensor = op. operand_source ( 1 );
856
+ if (!axis_tensor ) {
857
+ VLOG (3 ) << " pd_op.split_with_num can not find axis input " ;
860
858
return false ;
861
- } else {
862
- auto axis = full_op->attribute <paddle::dialect::ScalarAttribute>(" value" )
859
+ }
860
+ if (pir::GetDefiningOpForInput (op, 1 )
861
+ ->isa <paddle::dialect::FullIntArrayOp>()) {
862
+ paddle::dialect::FullIntArrayOp full_int_array_op =
863
+ pir::GetDefiningOpForInput (op, 1 )
864
+ ->dyn_cast <paddle::dialect::FullIntArrayOp>();
865
+ auto axis = full_int_array_op
866
+ ->attribute <paddle::dialect::ScalarAttribute>(" value" )
863
867
.data ()
864
868
.to <int >();
865
869
auto x_shape = op.operand_source (0 )
866
870
.type ()
867
871
.dyn_cast <paddle::dialect::DenseTensorType>()
868
872
.dims ();
869
- auto out_vector_type = op.result (0 ).type ().dyn_cast <pir::VectorType>();
870
873
871
874
axis += (axis < 0 ) ? x_shape.size () : 0 ;
872
875
if (x_shape[axis] == -1 ) {
873
876
VLOG (3 ) << " The (" << axis << " ) dim of input should not be -1" ;
874
877
return false ;
875
878
}
876
-
877
879
if (!op->HasAttribute (" num" )) {
878
880
VLOG (3 ) << " split_with_num op must has num attributes" ;
879
881
return false ;
880
882
}
881
883
int num = op->attribute <pir::Int32Attribute>(" num" ).data ();
882
884
std::vector<int64_t > output_lengths;
885
+
883
886
if (num > 0 ) {
884
887
int64_t in_axis_dim = x_shape[axis];
885
888
if (in_axis_dim % num != 0 ) {
@@ -893,14 +896,15 @@ class SplitWithNumOpPattern
893
896
output_lengths.push_back (out_axis_dim);
894
897
}
895
898
}
896
-
899
+ auto out_vector_type = op. result ( 0 ). type (). dyn_cast <pir::VectorType>();
897
900
if (out_vector_type.size () != output_lengths.size ()) {
898
901
VLOG (3 ) << " The output_length should be equal to the output size." ;
899
902
return false ;
900
903
}
901
- op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
902
- return true ;
903
904
}
905
+
906
+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
907
+ return true ;
904
908
}
905
909
};
906
910
class GreaterEqualOpPattern
0 commit comments