Skip to content

Commit c946b38

Browse files
committed
add trt min max opt shape
1 parent 9a68a61 commit c946b38

File tree

1 file changed

+77
-3
lines changed

1 file changed

+77
-3
lines changed

tools/infer/utility.py

+77-3
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,83 @@ def create_predictor(args, mode, logger):
139139
config.enable_use_gpu(args.gpu_mem, 0)
140140
if args.use_tensorrt:
141141
config.enable_tensorrt_engine(
142-
precision_mode=inference.PrecisionType.Half
143-
if args.use_fp16 else inference.PrecisionType.Float32,
144-
max_batch_size=args.max_batch_size)
142+
precision_mode=inference.PrecisionType.Float32,
143+
max_batch_size=args.max_batch_size,
144+
min_subgraph_size=3) # skip the minmum trt subgraph
145+
if mode == "det" and "mobile" in model_file_path:
146+
min_input_shape = {
147+
"x": [1, 3, 50, 50],
148+
"conv2d_92.tmp_0": [1, 96, 20, 20],
149+
"conv2d_91.tmp_0": [1, 96, 10, 10],
150+
"nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
151+
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
152+
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
153+
"nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
154+
"nearest_interp_v2_5.tmp_0": [1, 24, 20, 20],
155+
"elementwise_add_7": [1, 56, 2, 2],
156+
"nearest_interp_v2_0.tmp_0": [1, 96, 2, 2]
157+
}
158+
max_input_shape = {
159+
"x": [1, 3, 2000, 2000],
160+
"conv2d_92.tmp_0": [1, 96, 400, 400],
161+
"conv2d_91.tmp_0": [1, 96, 200, 200],
162+
"nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
163+
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
164+
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
165+
"nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
166+
"nearest_interp_v2_5.tmp_0": [1, 24, 400, 400],
167+
"elementwise_add_7": [1, 56, 400, 400],
168+
"nearest_interp_v2_0.tmp_0": [1, 96, 400, 400]
169+
}
170+
opt_input_shape = {
171+
"x": [1, 3, 640, 640],
172+
"conv2d_92.tmp_0": [1, 96, 160, 160],
173+
"conv2d_91.tmp_0": [1, 96, 80, 80],
174+
"nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
175+
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
176+
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
177+
"nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
178+
"nearest_interp_v2_5.tmp_0": [1, 24, 160, 160],
179+
"elementwise_add_7": [1, 56, 40, 40],
180+
"nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
181+
}
182+
if mode == "det" and "server" in model_file_path:
183+
min_input_shape = {
184+
"x": [1, 3, 50, 50],
185+
"conv2d_59.tmp_0": [1, 96, 20, 20],
186+
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
187+
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
188+
"nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
189+
"nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
190+
}
191+
max_input_shape = {
192+
"x": [1, 3, 2000, 2000],
193+
"conv2d_59.tmp_0": [1, 96, 400, 400],
194+
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
195+
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
196+
"nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
197+
"nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
198+
}
199+
opt_input_shape = {
200+
"x": [1, 3, 640, 640],
201+
"conv2d_59.tmp_0": [1, 96, 160, 160],
202+
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
203+
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
204+
"nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
205+
"nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
206+
}
207+
elif mode == "rec":
208+
min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
209+
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
210+
opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
211+
elif mode == "cls":
212+
min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]}
213+
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]}
214+
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
215+
216+
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
217+
opt_input_shape)
218+
145219
else:
146220
config.disable_gpu()
147221
if hasattr(args, "cpu_threads"):

0 commit comments

Comments
 (0)