Skip to content

Commit 596de6d

Browse files
Default FP16 TensorRT export (#6798)
* Assert engine precision #6777 * Default to FP32 inputs for TensorRT engines * Default to FP16 TensorRT exports #6777 * Remove wrong line #6777 * Automatically adjust detect.py input precision #6777 * Automatically adjust val.py input precision #6777 * Add missing colon * Cleanup * Cleanup * Remove default trt_fp16_input definition * Experiment * Reorder detect.py if statement to after half checks * Update common.py * Update export.py * Cleanup Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
1 parent 7e98b48 commit 596de6d

File tree

4 files changed

+13
-3
lines changed

4 files changed

+13
-3
lines changed

detect.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
9797
half &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 supported on limited backends with CUDA
9898
if pt or jit:
9999
model.model.half() if half else model.model.float()
100+
elif engine and model.trt_fp16_input != half:
101+
LOGGER.info('model ' + (
102+
'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.')
103+
half = model.trt_fp16_input
100104

101105
# Dataloader
102106
if webcam:

export.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,8 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
233233
for out in outputs:
234234
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
235235

236-
half &= builder.platform_has_fast_fp16
237-
LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
238-
if half:
236+
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 else 32} engine in {f}')
237+
if builder.platform_has_fast_fp16:
239238
config.set_flag(trt.BuilderFlag.FP16)
240239
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
241240
t.write(engine.serialize())

models/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
338338
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
339339
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
340340
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
341+
trt_fp16_input = False
341342
logger = trt.Logger(trt.Logger.INFO)
342343
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
343344
model = runtime.deserialize_cuda_engine(f.read())
@@ -348,6 +349,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
348349
shape = tuple(model.get_binding_shape(index))
349350
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
350351
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
352+
if model.binding_is_input(index) and dtype == np.float16:
353+
trt_fp16_input = True
351354
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
352355
context = model.create_execution_context()
353356
batch_size = bindings['images'].shape[0]

val.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ def run(data,
144144
model.model.half() if half else model.model.float()
145145
elif engine:
146146
batch_size = model.batch_size
147+
if model.trt_fp16_input != half:
148+
LOGGER.info('model ' + (
149+
'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.')
150+
half = model.trt_fp16_input
147151
else:
148152
half = False
149153
batch_size = 1 # export.py models default to batch-size 1

0 commit comments

Comments
 (0)