@@ -276,7 +276,7 @@ def forward(self, x):
276
276
277
277
class DetectMultiBackend (nn .Module ):
278
278
# YOLOv5 MultiBackend class for python inference on various backends
279
- def __init__ (self , weights = 'yolov5s.pt' , device = None , dnn = True ):
279
+ def __init__ (self , weights = 'yolov5s.pt' , device = None , dnn = False ):
280
280
# Usage:
281
281
# PyTorch: weights = *.pt
282
282
# TorchScript: *.torchscript
@@ -287,13 +287,16 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
287
287
# ONNX Runtime: *.onnx
288
288
# OpenCV DNN: *.onnx with dnn=True
289
289
# TensorRT: *.engine
290
+ from models .experimental import attempt_download , attempt_load # scoped to avoid circular import
291
+
290
292
super ().__init__ ()
291
293
w = str (weights [0 ] if isinstance (weights , list ) else weights )
292
294
suffix = Path (w ).suffix .lower ()
293
295
suffixes = ['.pt' , '.torchscript' , '.onnx' , '.engine' , '.tflite' , '.pb' , '' , '.mlmodel' ]
294
296
check_suffix (w , suffixes ) # check weights have acceptable suffix
295
297
pt , jit , onnx , engine , tflite , pb , saved_model , coreml = (suffix == x for x in suffixes ) # backend booleans
296
298
stride , names = 64 , [f'class{ i } ' for i in range (1000 )] # assign defaults
299
+ attempt_download (w ) # download if not local
297
300
298
301
if jit : # TorchScript
299
302
LOGGER .info (f'Loading { w } for TorchScript inference...' )
@@ -303,11 +306,12 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
303
306
d = json .loads (extra_files ['config.txt' ]) # extra_files dict
304
307
stride , names = int (d ['stride' ]), d ['names' ]
305
308
elif pt : # PyTorch
306
- from models .experimental import attempt_load # scoped to avoid circular import
307
309
model = attempt_load (weights , map_location = device )
308
310
stride = int (model .stride .max ()) # model stride
309
311
names = model .module .names if hasattr (model , 'module' ) else model .names # get class names
312
+ self .model = model # explicitly assign for to(), cpu(), cuda(), half()
310
313
elif coreml : # CoreML
314
+ LOGGER .info (f'Loading { w } for CoreML inference...' )
311
315
import coremltools as ct
312
316
model = ct .models .MLModel (w )
313
317
elif dnn : # ONNX OpenCV DNN
@@ -316,7 +320,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
316
320
net = cv2 .dnn .readNetFromONNX (w )
317
321
elif onnx : # ONNX Runtime
318
322
LOGGER .info (f'Loading { w } for ONNX Runtime inference...' )
319
- check_requirements (('onnx' , 'onnxruntime-gpu' if torch .has_cuda else 'onnxruntime' ))
323
+ check_requirements (('onnx' , 'onnxruntime-gpu' if torch .cuda . is_available () else 'onnxruntime' ))
320
324
import onnxruntime
321
325
session = onnxruntime .InferenceSession (w , None )
322
326
elif engine : # TensorRT
@@ -376,7 +380,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
376
380
if self .pt : # PyTorch
377
381
y = self .model (im ) if self .jit else self .model (im , augment = augment , visualize = visualize )
378
382
return y if val else y [0 ]
379
- elif self .coreml : # CoreML *.mlmodel
383
+ elif self .coreml : # CoreML
380
384
im = im .permute (0 , 2 , 3 , 1 ).cpu ().numpy () # torch BCHW to numpy BHWC shape(1,320,192,3)
381
385
im = Image .fromarray ((im [0 ] * 255 ).astype ('uint8' ))
382
386
# im = im.resize((192, 320), Image.ANTIALIAS)
@@ -433,24 +437,28 @@ class AutoShape(nn.Module):
433
437
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
434
438
conf = 0.25 # NMS confidence threshold
435
439
iou = 0.45 # NMS IoU threshold
436
- classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
440
+ agnostic = False # NMS class-agnostic
437
441
multi_label = False # NMS multiple labels per box
442
+ classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
438
443
max_det = 1000 # maximum number of detections per image
439
444
440
445
def __init__ (self , model ):
441
446
super ().__init__ ()
442
447
LOGGER .info ('Adding AutoShape... ' )
443
448
copy_attr (self , model , include = ('yaml' , 'nc' , 'hyp' , 'names' , 'stride' , 'abc' ), exclude = ()) # copy attributes
449
+ self .dmb = isinstance (model , DetectMultiBackend ) # DetectMultiBackend() instance
450
+ self .pt = not self .dmb or model .pt # PyTorch model
444
451
self .model = model .eval ()
445
452
446
453
def _apply (self , fn ):
447
454
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
448
455
self = super ()._apply (fn )
449
- m = self .model .model [- 1 ] # Detect()
450
- m .stride = fn (m .stride )
451
- m .grid = list (map (fn , m .grid ))
452
- if isinstance (m .anchor_grid , list ):
453
- m .anchor_grid = list (map (fn , m .anchor_grid ))
456
+ if self .pt :
457
+ m = self .model .model .model [- 1 ] if self .dmb else self .model .model [- 1 ] # Detect()
458
+ m .stride = fn (m .stride )
459
+ m .grid = list (map (fn , m .grid ))
460
+ if isinstance (m .anchor_grid , list ):
461
+ m .anchor_grid = list (map (fn , m .anchor_grid ))
454
462
return self
455
463
456
464
@torch .no_grad ()
@@ -465,7 +473,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
465
473
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
466
474
467
475
t = [time_sync ()]
468
- p = next (self .model .parameters ()) # for device and type
476
+ p = next (self .model .parameters ()) if self . pt else torch . zeros ( 1 ) # for device and type
469
477
if isinstance (imgs , torch .Tensor ): # torch
470
478
with amp .autocast (enabled = p .device .type != 'cpu' ):
471
479
return self .model (imgs .to (p .device ).type_as (p ), augment , profile ) # inference
@@ -489,21 +497,21 @@ def forward(self, imgs, size=640, augment=False, profile=False):
489
497
g = (size / max (s )) # gain
490
498
shape1 .append ([y * g for y in s ])
491
499
imgs [i ] = im if im .data .contiguous else np .ascontiguousarray (im ) # update
492
- shape1 = [make_divisible (x , int ( self .stride . max ()) ) for x in np .stack (shape1 , 0 ).max (0 )] # inference shape
493
- x = [letterbox (im , new_shape = shape1 , auto = False )[0 ] for im in imgs ] # pad
500
+ shape1 = [make_divisible (x , self .stride ) for x in np .stack (shape1 , 0 ).max (0 )] # inference shape
501
+ x = [letterbox (im , new_shape = shape1 if self . pt else size , auto = False )[0 ] for im in imgs ] # pad
494
502
x = np .stack (x , 0 ) if n > 1 else x [0 ][None ] # stack
495
503
x = np .ascontiguousarray (x .transpose ((0 , 3 , 1 , 2 ))) # BHWC to BCHW
496
504
x = torch .from_numpy (x ).to (p .device ).type_as (p ) / 255 # uint8 to fp16/32
497
505
t .append (time_sync ())
498
506
499
507
with amp .autocast (enabled = p .device .type != 'cpu' ):
500
508
# Inference
501
- y = self .model (x , augment , profile )[ 0 ] # forward
509
+ y = self .model (x , augment , profile ) # forward
502
510
t .append (time_sync ())
503
511
504
512
# Post-process
505
- y = non_max_suppression (y , self .conf , iou_thres = self .iou , classes = self .classes ,
506
- multi_label = self .multi_label , max_det = self .max_det ) # NMS
513
+ y = non_max_suppression (y if self . dmb else y [ 0 ] , self .conf , iou_thres = self .iou , classes = self .classes ,
514
+ agnostic = self . agnostic , multi_label = self .multi_label , max_det = self .max_det ) # NMS
507
515
for i in range (n ):
508
516
scale_coords (shape1 , y [i ][:, :4 ], shape0 [i ])
509
517
0 commit comments