Add timing cache to accelerate subsequent .engine export #13386

Merged
merged 3 commits on Nov 8, 2024
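For context: when TensorRT builds an engine it times many candidate kernels ("tactics") for each layer, and this search dominates export time. The timing cache added here persists those measurements so repeated exports of the same model skip the search. A minimal usage sketch, assuming this file is YOLOv5's export.py and reusing its existing run() parameters (weights, include, device); the cache filename is illustrative:

from export import run

# First call builds the engine from scratch and writes the timing cache to disk.
run(weights="yolov5s.pt", include=("engine",), device="0", cache="yolov5s.timing.cache")
# A second call with the same cache path reuses the saved tactics and finishes much faster.
run(weights="yolov5s.pt", include=("engine",), device="0", cache="yolov5s.timing.cache")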
export.py: 22 changes (19 additions, 3 deletions)
@@ -593,7 +593,9 @@ def export_coreml(model, im, file, int8, half, nms, mlmodel, prefix=colorstr("Co


 @try_export
-def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")):
+def export_engine(
+    model, im, file, half, dynamic, simplify, workspace=4, verbose=False, cache="", prefix=colorstr("TensorRT:")
+):
     """
     Export a YOLOv5 model to TensorRT engine format, requiring GPU and TensorRT>=7.0.0.

@@ -606,6 +608,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
         simplify (bool): Set to True to simplify the model during export.
         workspace (int): Workspace size in GB (default is 4).
         verbose (bool): Set to True for verbose logging output.
+        cache (str): Path to save the TensorRT timing cache.
         prefix (str): Log message prefix.

     Returns:
@@ -660,6 +663,11 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
         config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)
     else:  # TensorRT versions 7, 8
         config.max_workspace_size = workspace * 1 << 30
+    if cache:  # enable timing cache
+        Path(cache).parent.mkdir(parents=True, exist_ok=True)
+        buf = Path(cache).read_bytes() if Path(cache).exists() else b""
+        timing_cache = config.create_timing_cache(buf)
+        config.set_timing_cache(timing_cache, ignore_mismatch=True)
     flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
     network = builder.create_network(flag)
     parser = trt.OnnxParser(network, logger)
@@ -688,6 +696,9 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
     build = builder.build_serialized_network if is_trt10 else builder.build_engine
     with build(network, config) as engine, open(f, "wb") as t:
         t.write(engine if is_trt10 else engine.serialize())
+    if cache:  # save timing cache
+        with open(cache, "wb") as c:
+            c.write(config.get_timing_cache().serialize())
     return f, None


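Together, the two hunks above form a load-use-save round trip for TensorRT's timing cache: seed the builder config from any cache bytes on disk, build, then serialize the (possibly updated) cache back out. A condensed standalone sketch of that pattern, with builder and network construction elided and TensorRT >= 8 assumed; the helper names are our own:

from pathlib import Path

def attach_timing_cache(config, cache_path):
    # Seed the config with cache bytes from disk, or start a fresh cache from b"".
    buf = Path(cache_path).read_bytes() if Path(cache_path).exists() else b""
    timing_cache = config.create_timing_cache(buf)
    config.set_timing_cache(timing_cache, ignore_mismatch=True)  # tolerate GPU/TRT-version mismatches

def save_timing_cache(config, cache_path):
    # Persist the tactics timed during this build for the next build to reuse.
    with open(cache_path, "wb") as f:
        f.write(config.get_timing_cache().serialize())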
@@ -1277,6 +1288,7 @@ def run(
     int8=False,  # CoreML/TF INT8 quantization
     per_tensor=False,  # TF per tensor quantization
     dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
+    cache="",  # TensorRT: timing cache path
     simplify=False,  # ONNX: simplify model
     mlmodel=False,  # CoreML: Export in *.mlmodel format
     opset=12,  # ONNX: opset version
@@ -1306,6 +1318,7 @@ def run(
         int8 (bool): Apply INT8 quantization for CoreML or TensorFlow models. Default is False.
         per_tensor (bool): Apply per tensor quantization for TensorFlow models. Default is False.
         dynamic (bool): Enable dynamic axes for ONNX, TensorFlow, or TensorRT exports. Default is False.
+        cache (str): TensorRT timing cache path. Default is an empty string.
         simplify (bool): Simplify the ONNX model during export. Default is False.
         opset (int): ONNX opset version. Default is 12.
         verbose (bool): Enable verbose logging for TensorRT export. Default is False.
@@ -1341,6 +1354,7 @@ def run(
         int8=False,
         per_tensor=False,
         dynamic=False,
+        cache="",
         simplify=False,
         opset=12,
         verbose=False,
@@ -1378,7 +1392,8 @@ def run(
     # Input
     gs = int(max(model.stride))  # grid size (max stride)
     imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
-    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection
+    ch = next(model.parameters()).size(1)  # required input image channels
+    im = torch.zeros(batch_size, ch, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

     # Update model
     model.eval()
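The hunk above is a related but separate fix: the dummy export tensor previously hard-coded 3 input channels, so models with a different channel count could not be exported correctly. Inferring the channel count from the first model parameter fixes that. A toy illustration (the 4-channel model is hypothetical):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(4, 16, 3))  # e.g. an RGB-D model with 4 input channels
ch = next(model.parameters()).size(1)  # Conv2d weight shape is (out_ch, in_ch, kH, kW), so this is 4
im = torch.zeros(1, ch, 640, 640)  # dummy input now matches what the model expects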
@@ -1402,7 +1417,7 @@ def run(
     if jit:  # TorchScript
         f[0], _ = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
-        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
+        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose, cache)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
     if xml:  # OpenVINO
@@ -1497,6 +1512,7 @@ def parse_opt(known=False):
     parser.add_argument("--int8", action="store_true", help="CoreML/TF/OpenVINO INT8 quantization")
     parser.add_argument("--per-tensor", action="store_true", help="TF per-tensor quantization")
     parser.add_argument("--dynamic", action="store_true", help="ONNX/TF/TensorRT: dynamic axes")
+    parser.add_argument("--cache", type=str, default="", help="TensorRT: timing cache file path")
     parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
     parser.add_argument("--mlmodel", action="store_true", help="CoreML: Export in *.mlmodel format")
    parser.add_argument("--opset", type=int, default=17, help="ONNX: opset version")
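With the flag registered in parse_opt, the feature is driven from the command line like any other export option; a typical invocation (standard YOLOv5 flags, illustrative cache filename) might be:

python export.py --weights yolov5s.pt --include engine --device 0 --cache yolov5s.timing.cache

Because the default is an empty string, which is falsy, both `if cache:` branches in export_engine are skipped when the flag is omitted and the previous behavior is unchanged.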
train.py: 6 changes (3 additions, 3 deletions)
@@ -717,10 +717,10 @@ def main(opt, callbacks=Callbacks()):
         "perspective": (True, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
         "flipud": (True, 0.0, 1.0),  # image flip up-down (probability)
         "fliplr": (True, 0.0, 1.0),  # image flip left-right (probability)
-        "mosaic": (True, 0.0, 1.0),  # image mixup (probability)
+        "mosaic": (True, 0.0, 1.0),  # image mosaic (probability)
         "mixup": (True, 0.0, 1.0),  # image mixup (probability)
-        "copy_paste": (True, 0.0, 1.0),
-    }  # segment copy-paste (probability)
+        "copy_paste": (True, 0.0, 1.0),  # segment copy-paste (probability)
+    }

     # GA configs
     pop_size = 50
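The train.py edits only correct comments in the hyperparameter-evolution metadata: "mosaic" was mislabeled as mixup, and the "copy_paste" comment sat outside the dict literal. For orientation, each tuple appears to read as (mutation enabled, lower bound, upper bound) — our reading of the surrounding GA code, not stated in this diff — and a hypothetical consumer would mutate and clamp like this:

import random

meta = {"mosaic": (True, 0.0, 1.0), "copy_paste": (True, 0.0, 1.0)}  # (enabled, lower, upper) -- assumed semantics
hyp = {"mosaic": 1.0, "copy_paste": 0.0}  # current hyperparameter values

for k, (enabled, lo, hi) in meta.items():
    if enabled:
        hyp[k] = min(max(hyp[k] + random.gauss(0.0, 0.2), lo), hi)  # perturb, then clamp to bounds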