Skip to content

Commit 932dc78

Browse files
authored
YOLOv5 Export Benchmarks for GPU (#6963)
* Add benchmarks.py GPU support * Updates * Updates * Updates * Updates * Add --half * Add TRT requirements * Cleanup * Add TF to warmup types * Update export.py * Update export.py * Update benchmarks.py
1 parent 99de551 commit 932dc78

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

export.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,18 @@
7575

7676
def export_formats():
7777
# YOLOv5 export formats
78-
x = [['PyTorch', '-', '.pt'],
79-
['TorchScript', 'torchscript', '.torchscript'],
80-
['ONNX', 'onnx', '.onnx'],
81-
['OpenVINO', 'openvino', '_openvino_model'],
82-
['TensorRT', 'engine', '.engine'],
83-
['CoreML', 'coreml', '.mlmodel'],
84-
['TensorFlow SavedModel', 'saved_model', '_saved_model'],
85-
['TensorFlow GraphDef', 'pb', '.pb'],
86-
['TensorFlow Lite', 'tflite', '.tflite'],
87-
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'],
88-
['TensorFlow.js', 'tfjs', '_web_model']]
89-
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix'])
78+
x = [['PyTorch', '-', '.pt', True],
79+
['TorchScript', 'torchscript', '.torchscript', True],
80+
['ONNX', 'onnx', '.onnx', True],
81+
['OpenVINO', 'openvino', '_openvino_model', False],
82+
['TensorRT', 'engine', '.engine', True],
83+
['CoreML', 'coreml', '.mlmodel', False],
84+
['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
85+
['TensorFlow GraphDef', 'pb', '.pb', True],
86+
['TensorFlow Lite', 'tflite', '.tflite', False],
87+
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
88+
['TensorFlow.js', 'tfjs', '_web_model', False]]
89+
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])
9090

9191

9292
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):

models/common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,10 +464,11 @@ def forward(self, im, augment=False, visualize=False, val=False):
464464

465465
def warmup(self, imgsz=(1, 3, 640, 640)):
466466
# Warmup model by running inference once
467-
if self.pt or self.jit or self.onnx or self.engine: # warmup types
468-
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
467+
if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types
468+
if self.device.type != 'cpu': # only warmup GPU models
469469
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
470-
self.forward(im) # warmup
470+
for _ in range(2 if self.jit else 1): #
471+
self.forward(im) # warmup
471472

472473
@staticmethod
473474
def model_type(p='path/to/model.pt'):

utils/benchmarks.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Requirements:
2020
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
2121
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
22+
$ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com # TensorRT
2223
2324
Usage:
2425
$ python utils/benchmarks.py --weights yolov5s.pt --img 640
@@ -41,20 +42,29 @@
4142
import val
4243
from utils import notebook_init
4344
from utils.general import LOGGER, print_args
45+
from utils.torch_utils import select_device
4446

4547

4648
def run(weights=ROOT / 'yolov5s.pt', # weights path
4749
imgsz=640, # inference size (pixels)
4850
batch_size=1, # batch size
4951
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
52+
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
53+
half=False, # use FP16 half-precision inference
5054
):
5155
y, t = [], time.time()
5256
formats = export.export_formats()
53-
for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix)
57+
device = select_device(device)
58+
for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable)
5459
try:
55-
w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1]
60+
if device.type != 'cpu':
61+
assert gpu, f'{name} inference not supported on GPU'
62+
if f == '-':
63+
w = weights # PyTorch format
64+
else:
65+
w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others
5666
assert suffix in str(w), 'export failed'
57-
result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device='cpu', task='benchmark')
67+
result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half)
5868
metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls))
5969
speeds = result[2] # times (preprocess, inference, postprocess)
6070
y.append([name, metrics[3], speeds[1]]) # mAP, t_inference
@@ -78,6 +88,8 @@ def parse_opt():
7888
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
7989
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
8090
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
91+
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
92+
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
8193
opt = parser.parse_args()
8294
print_args(FILE.stem, opt)
8395
return opt

0 commit comments

Comments
 (0)