class TrtConvertConv2dTest(TrtLayerAutoScanTest):
    def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        # TODO: This is just the example to remove the wrong attrs.
        inputs = program_config.inputs
        weights = program_config.weights
        attrs = [
            program_config.ops[i].attrs
            for i in range(len(program_config.ops))
        ]

-        # groups restriction.
        if inputs['input_data'].shape[1] != weights['conv2d_weight'].shape[
                1] * attrs[0]['groups']:
            return False

-        # others restriction, todo.
-
        return True
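
    # The check above enforces the grouped-conv channel rule: for a filter of
    # shape [out_c, in_c_per_group, kh, kw], the input must carry
    # filter.shape[1] * groups channels (e.g. weight [24, 3, 3, 3] with
    # groups == 2 needs a 6-channel input).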

    def sample_program_configs(self):
-        def generate_input1(attrs: List[Dict[str, Any]]):
-            # TODO: This is just the example to illustrate the releation between axis and input.
-            # for each attr, can generate different datas
+        self.trt_param.workspace_size = 1073741824
+
+        def generate_input1(batch, attrs: List[Dict[str, Any]]):
            if attrs[0]['groups'] == 1:
-                return np.ones([2, 3, 64, 64]).astype(np.float32)
+                return np.ones([batch, 3, 64, 64]).astype(np.float32)
+            elif attrs[0]['groups'] == 2:
+                return np.ones([batch, 6, 64, 64]).astype(np.float32)
            else:
-                return np.ones([1, 3, 64, 64]).astype(np.float32)
+                return np.ones([batch, 9, 64, 64]).astype(np.float32)

        def generate_weight1(attrs: List[Dict[str, Any]]):
            return np.random.random([24, 3, 3, 3]).astype(np.float32)
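        # The filter shape fixes out_channels = 24 and in_channels_per_group
        # = 3, which is why generate_input1 supplies 3-, 6-, or 9-channel
        # inputs for groups 1, 2, and 3.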

-        # for strides in [[1, 1], [2, 2], [1, 2], [2, 3]]:
-        #     for paddings in [[0, 3], [3, 1], [1, 1, 1, 1]]:
-        #         for groups in [1, 2]:
-        #             for padding_algotithm in ['EXPLICIT', 'SAME', 'VALID']:
-        #                 for dilations in [[1, 1], [1, 2]]:
-        #                     for data_format in ['NCHW']:
-        for strides in [[1, 1], [2, 2]]:
-            for paddings in [[0, 3], [3, 1]]:
-                for groups in [1]:
-                    for padding_algotithm in ['EXPLICIT']:
-                        for dilations in [[1, 1]]:
-                            for data_format in ['NCHW']:
-
-                                dics = [{
-                                    "data_fromat": data_format,
-                                    "dilations": dilations,
-                                    "padding_algorithm": padding_algotithm,
-                                    "groups": groups,
-                                    "paddings": paddings,
-                                    "strides": strides,
-                                    "data_format": data_format
-                                }, {}]
-
-                                ops_config = [{
-                                    "op_type": "conv2d",
-                                    "op_inputs": {
-                                        "Input": ["input_data"],
-                                        "Filter": ["conv2d_weight"]
-                                    },
-                                    "op_outputs": {
-                                        "Output": ["conv_output_data"]
-                                    },
-                                    "op_attrs": dics[0]
-                                }, {
-                                    "op_type": "relu",
-                                    "op_inputs": {
-                                        "X": ["conv_output_data"]
-                                    },
-                                    "op_outputs": {
-                                        "Out": ["relu_output_data"]
-                                    },
-                                    "op_attrs": dics[1]
-                                }]
-                                ops = self.generate_op_config(ops_config)
-
-                                program_config = ProgramConfig(
-                                    ops=ops,
-                                    weights={
-                                        "conv2d_weight": TensorConfig(
-                                            data_gen=partial(generate_weight1,
-                                                             dics))
-                                    },
-                                    inputs={
-                                        "input_data": TensorConfig(
-                                            data_gen=partial(generate_input1,
-                                                             dics))
-                                    },
-                                    outputs=["relu_output_data"])
-
-                                yield program_config
+        for batch in [1, 2, 4]:
+            for strides in [[1, 1], [2, 2], [1, 2]]:
+                for paddings in [[0, 3], [1, 2, 3, 4]]:
+                    for groups in [1, 2, 3]:
+                        for padding_algorithm in ['EXPLICIT', 'SAME', 'VALID']:
+                            for dilations in [[1, 1], [2, 2], [1, 2]]:
+                                for data_format in ['NCHW']:
+
+                                    dics = [{
+                                        "dilations": dilations,
+                                        "padding_algorithm": padding_algorithm,
+                                        "groups": groups,
+                                        "paddings": paddings,
+                                        "strides": strides,
+                                        "data_format": data_format
+                                    }, {}]
+
+                                    if padding_algorithm == 'EXPLICIT':
+                                        ops_config = [{
+                                            "op_type": "conv2d",
+                                            "op_inputs": {
+                                                "Input": ["input_data"],
+                                                "Filter": ["conv2d_weight"]
+                                            },
+                                            "op_outputs": {
+                                                "Output": ["conv_output_data"]
+                                            },
+                                            "op_attrs": dics[0]
+                                        }, {
+                                            "op_type": "relu",
+                                            "op_inputs": {
+                                                "X": ["conv_output_data"]
+                                            },
+                                            "op_outputs": {
+                                                "Out": ["output_data"]
+                                            },
+                                            "op_attrs": dics[1]
+                                        }]
+                                    else:
+                                        ops_config = [{
+                                            "op_type": "conv2d",
+                                            "op_inputs": {
+                                                "Input": ["input_data"],
+                                                "Filter": ["conv2d_weight"]
+                                            },
+                                            "op_outputs": {
+                                                "Output": ["output_data"]
+                                            },
+                                            "op_attrs": dics[0]
+                                        }]
+                                    ops = self.generate_op_config(ops_config)
+
+                                    program_config = ProgramConfig(
+                                        ops=ops,
+                                        weights={
+                                            "conv2d_weight":
+                                            TensorConfig(data_gen=partial(
+                                                generate_weight1, dics))
+                                        },
+                                        inputs={
+                                            "input_data":
+                                            TensorConfig(data_gen=partial(
+                                                generate_input1, batch, dics))
+                                        },
+                                        outputs=["output_data"])
+
+                                    yield program_config
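+
+        # The sweep above yields 3 * 3 * 2 * 3 * 3 * 3 = 486 program configs
+        # (batch x strides x paddings x groups x padding_algorithm x
+        # dilations).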

    def sample_predictor_configs(
            self, program_config) -> (paddle_infer.Config, List[int], float):
        def generate_dynamic_shape(attrs):
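            # Branch on groups: it fixes the channel count generate_input1
            # produces (3, 6, or 9), which the shape profiles must match.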
-            if len(attrs[0]['paddings']) == 4:
+            if attrs[0]['groups'] == 1:
                self.dynamic_shape.min_input_shape = {
                    "input_data": [1, 3, 32, 32],
-                    '': []
+                    "output_data": [1, 24, 32, 32]
                }
                self.dynamic_shape.max_input_shape = {
                    "input_data": [4, 3, 64, 64],
-                    '': []
+                    "output_data": [4, 24, 64, 64]
                }
                self.dynamic_shape.opt_input_shape = {
                    "input_data": [1, 3, 64, 64],
-                    '': []
+                    "output_data": [1, 24, 64, 64]
+                }
+            elif attrs[0]['groups'] == 2:
+                self.dynamic_shape.min_input_shape = {
+                    "input_data": [1, 6, 32, 32],
+                    "output_data": [1, 24, 32, 32]
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [4, 6, 64, 64],
+                    "output_data": [4, 24, 64, 64]
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_data": [1, 6, 64, 64],
+                    "output_data": [1, 24, 64, 64]
                }
            else:
                self.dynamic_shape.min_input_shape = {
-                    "input_data": [1, 3, 32, 32]
+                    "input_data": [1, 9, 32, 32],
+                    "output_data": [1, 24, 32, 32]
                }
                self.dynamic_shape.max_input_shape = {
-                    "input_data": [4, 3, 64, 64]
+                    "input_data": [4, 9, 64, 64],
+                    "output_data": [4, 24, 64, 64]
                }
                self.dynamic_shape.opt_input_shape = {
-                    "input_data": [1, 3, 64, 64]
+                    "input_data": [1, 9, 64, 64],
+                    "output_data": [1, 24, 64, 64]
                }
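            # The three shape maps above define the TensorRT optimization
            # profile (min/opt/max); "output_data" always has 24 channels
            # because the filter out-channel count is fixed at 24.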

        def clear_dynamic_shape():
@@ -145,11 +167,7 @@ def clear_dynamic_shape():
            self.dynamic_shape.opt_input_shape = {}
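            # Emptying all three profiles reverts the test to the static-shape
            # path; the dynamic round re-registers them via
            # generate_dynamic_shape below.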

        def generate_trt_nodes_num(attrs, dynamic_shape):
-            # TODO: This is just the example, need to be fixed.
-            if len(attrs[0]['paddings']) == 4:
-                return 1, 2
-            else:
-                return 1, 2
+            # The expected node counts are the same for every attribute
+            # combination, so the paddings-based branching was collapsed.
+            return 1, 2

        attrs = [
            program_config.ops[i].attrs
@@ -169,6 +187,7 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
                attrs, False), (1e-5, 1e-5)
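        # Each yield hands the harness a predictor config, the expected node
        # counts from generate_trt_nodes_num, and a (1e-5, 1e-5) tolerance
        # pair, presumably (atol, rtol).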

        # for dynamic_shape
+
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
@@ -181,29 +200,18 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
                attrs, True), (1e-5, 1e-5)

    def add_skip_trt_case(self):
-        # TODO(wilber): This is just the example to illustrate the skip usage.
        def teller1(program_config, predictor_config):
-            if len(program_config.ops[0].attrs['paddings']) == 4:
+            if program_config.ops[0].attrs[
+                    'padding_algorithm'] == "SAME" or program_config.ops[
+                        0].attrs['padding_algorithm'] == "VALID":
                return True
            return False

        self.add_skip_case(
            teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
-            "NOT Implemented: we need to add support in the future ....TODO, just for the example "
+            "TRT does not support padding_algorithm 'SAME' or 'VALID'; the TRT build error in these cases is caused by the scale op."
        )
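
        # add_skip_case registers a predicate over (program_config,
        # predictor_config): any sampled combination for which teller1
        # returns True is skipped with the reason recorded above.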

-        def teller2(program_config, predictor_config):
-            if (
-                    program_config.ops[0].attrs['dilations'][0] == 1 and
-                    program_config.ops[0].attrs['dilations'][0] == 2
-            ) or program_config.ops[0].attrs['padding_algorithm'] != 'EXPLICIT':
-                return True
-            return False
-
-        self.add_skip_case(teller2, SkipReasons.TRT_NOT_SUPPORT,
-                           "TODO, just for the example")
-        pass
-

    def test(self):
        self.add_skip_trt_case()
        self.run_test()
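
# A minimal standalone sketch (an illustration, not part of the commit) of
# the output-size arithmetic implied by the swept conv2d attrs under
# 'EXPLICIT' padding; the helper name conv2d_out_hw is hypothetical.
def conv2d_out_hw(in_hw, k_hw, strides, paddings, dilations):
    # Paddle's conv2d accepts paddings as [ph, pw] (symmetric) or as
    # [top, bottom, left, right].
    if len(paddings) == 2:
        paddings = [paddings[0], paddings[0], paddings[1], paddings[1]]
    out = []
    for i in range(2):
        pad = paddings[2 * i] + paddings[2 * i + 1]
        k_eff = dilations[i] * (k_hw[i] - 1) + 1  # dilated kernel extent
        out.append((in_hw[i] + pad - k_eff) // strides[i] + 1)
    return out

# e.g. a 64x64 input, 3x3 filter, strides [2, 2], paddings [0, 3],
# dilations [1, 1]: h = (64 + 0 - 3) // 2 + 1 = 31,
# w = (64 + 3 + 3 - 3) // 2 + 1 = 34, so the output is 31x34.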