Skip to content

Commit 7bf04d9

Browse files
authored
AutoShape() models as DetectMultiBackend() instances (#5845)
* Update AutoShape() * autodownload ONNX * Cleanup * Finish updates * Add Usage * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * fix device * Update hubconf.py * Update common.py * smart param selection * autodownload all formats * autopad only pytorch models * new_shape edits * stride tensor fix * Cleanup
1 parent d885799 commit 7bf04d9

File tree

4 files changed

+35
-25
lines changed

4 files changed

+35
-25
lines changed

export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def parse_opt():
411411
parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
412412
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
413413
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
414-
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
414+
parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version')
415415
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
416416
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
417417
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')

hubconf.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Usage:
66
import torch
77
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
8+
model = torch.hub.load('ultralytics/yolov5:master', 'custom', 'path/to/yolov5s.onnx') # file from branch
89
"""
910

1011
import torch
@@ -27,26 +28,25 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
2728
"""
2829
from pathlib import Path
2930

30-
from models.common import AutoShape
31-
from models.experimental import attempt_load
31+
from models.common import AutoShape, DetectMultiBackend
3232
from models.yolo import Model
3333
from utils.downloads import attempt_download
3434
from utils.general import check_requirements, intersect_dicts, set_logging
3535
from utils.torch_utils import select_device
3636

37-
file = Path(__file__).resolve()
3837
check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
3938
set_logging(verbose=verbose)
4039

41-
save_dir = Path('') if str(name).endswith('.pt') else file.parent
42-
path = (save_dir / name).with_suffix('.pt') # checkpoint path
40+
name = Path(name)
41+
path = name.with_suffix('.pt') if name.suffix == '' else name # checkpoint path
4342
try:
4443
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
4544

4645
if pretrained and channels == 3 and classes == 80:
47-
model = attempt_load(path, map_location=device) # download/load FP32 model
46+
model = DetectMultiBackend(path, device=device) # download/load FP32 model
47+
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
4848
else:
49-
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
49+
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.name}.yaml'))[0] # model.yaml path
5050
model = Model(cfg, channels, classes) # create model
5151
if pretrained:
5252
ckpt = torch.load(attempt_download(path), map_location=device) # load

models/common.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def forward(self, x):
276276

277277
class DetectMultiBackend(nn.Module):
278278
# YOLOv5 MultiBackend class for python inference on various backends
279-
def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
279+
def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
280280
# Usage:
281281
# PyTorch: weights = *.pt
282282
# TorchScript: *.torchscript
@@ -287,13 +287,16 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
287287
# ONNX Runtime: *.onnx
288288
# OpenCV DNN: *.onnx with dnn=True
289289
# TensorRT: *.engine
290+
from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
291+
290292
super().__init__()
291293
w = str(weights[0] if isinstance(weights, list) else weights)
292294
suffix = Path(w).suffix.lower()
293295
suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
294296
check_suffix(w, suffixes) # check weights have acceptable suffix
295297
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
296298
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
299+
attempt_download(w) # download if not local
297300

298301
if jit: # TorchScript
299302
LOGGER.info(f'Loading {w} for TorchScript inference...')
@@ -303,11 +306,12 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
303306
d = json.loads(extra_files['config.txt']) # extra_files dict
304307
stride, names = int(d['stride']), d['names']
305308
elif pt: # PyTorch
306-
from models.experimental import attempt_load # scoped to avoid circular import
307309
model = attempt_load(weights, map_location=device)
308310
stride = int(model.stride.max()) # model stride
309311
names = model.module.names if hasattr(model, 'module') else model.names # get class names
312+
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
310313
elif coreml: # CoreML
314+
LOGGER.info(f'Loading {w} for CoreML inference...')
311315
import coremltools as ct
312316
model = ct.models.MLModel(w)
313317
elif dnn: # ONNX OpenCV DNN
@@ -316,7 +320,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
316320
net = cv2.dnn.readNetFromONNX(w)
317321
elif onnx: # ONNX Runtime
318322
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
319-
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
323+
check_requirements(('onnx', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'))
320324
import onnxruntime
321325
session = onnxruntime.InferenceSession(w, None)
322326
elif engine: # TensorRT
@@ -376,7 +380,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
376380
if self.pt: # PyTorch
377381
y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
378382
return y if val else y[0]
379-
elif self.coreml: # CoreML *.mlmodel
383+
elif self.coreml: # CoreML
380384
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
381385
im = Image.fromarray((im[0] * 255).astype('uint8'))
382386
# im = im.resize((192, 320), Image.ANTIALIAS)
@@ -433,24 +437,28 @@ class AutoShape(nn.Module):
433437
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
434438
conf = 0.25 # NMS confidence threshold
435439
iou = 0.45 # NMS IoU threshold
436-
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
440+
agnostic = False # NMS class-agnostic
437441
multi_label = False # NMS multiple labels per box
442+
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
438443
max_det = 1000 # maximum number of detections per image
439444

440445
def __init__(self, model):
441446
super().__init__()
442447
LOGGER.info('Adding AutoShape... ')
443448
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
449+
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
450+
self.pt = not self.dmb or model.pt # PyTorch model
444451
self.model = model.eval()
445452

446453
def _apply(self, fn):
447454
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
448455
self = super()._apply(fn)
449-
m = self.model.model[-1] # Detect()
450-
m.stride = fn(m.stride)
451-
m.grid = list(map(fn, m.grid))
452-
if isinstance(m.anchor_grid, list):
453-
m.anchor_grid = list(map(fn, m.anchor_grid))
456+
if self.pt:
457+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
458+
m.stride = fn(m.stride)
459+
m.grid = list(map(fn, m.grid))
460+
if isinstance(m.anchor_grid, list):
461+
m.anchor_grid = list(map(fn, m.anchor_grid))
454462
return self
455463

456464
@torch.no_grad()
@@ -465,7 +473,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
465473
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
466474

467475
t = [time_sync()]
468-
p = next(self.model.parameters()) # for device and type
476+
p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
469477
if isinstance(imgs, torch.Tensor): # torch
470478
with amp.autocast(enabled=p.device.type != 'cpu'):
471479
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
@@ -489,21 +497,21 @@ def forward(self, imgs, size=640, augment=False, profile=False):
489497
g = (size / max(s)) # gain
490498
shape1.append([y * g for y in s])
491499
imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
492-
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
493-
x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
500+
shape1 = [make_divisible(x, self.stride) for x in np.stack(shape1, 0).max(0)] # inference shape
501+
x = [letterbox(im, new_shape=shape1 if self.pt else size, auto=False)[0] for im in imgs] # pad
494502
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
495503
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
496504
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
497505
t.append(time_sync())
498506

499507
with amp.autocast(enabled=p.device.type != 'cpu'):
500508
# Inference
501-
y = self.model(x, augment, profile)[0] # forward
509+
y = self.model(x, augment, profile) # forward
502510
t.append(time_sync())
503511

504512
# Post-process
505-
y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes,
506-
multi_label=self.multi_label, max_det=self.max_det) # NMS
513+
y = non_max_suppression(y if self.dmb else y[0], self.conf, iou_thres=self.iou, classes=self.classes,
514+
agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det) # NMS
507515
for i in range(n):
508516
scale_coords(shape1, y[i][:, :4], shape0[i])
509517

utils/general.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,9 @@ def download_one(url, dir):
455455

456456

457457
def make_divisible(x, divisor):
458-
# Returns x evenly divisible by divisor
458+
# Returns x rounded up to the smallest multiple of divisor that is >= x
459+
if isinstance(divisor, torch.Tensor):
460+
divisor = int(divisor.max()) # to int
459461
return math.ceil(x / divisor) * divisor
460462

461463

0 commit comments

Comments
 (0)