- # YOLOv5 by Ultralytics, GPL-3.0 license
+ # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Common modules
"""
from PIL import Image
from torch import nn, Tensor
from torch.cuda import amp
- from yolort.v5.utils.datasets import exif_transpose, letterbox
from yolort.v5.utils.general import (
    colorstr,
    increment_path,
    is_ascii,
    make_divisible,
    non_max_suppression,
-     save_one_box,
    scale_coords,
    xyxy2xywh,
)
- from yolort.v5.utils.plots import Annotator, colors
- from yolort.v5.utils.torch_utils import time_sync
+ from yolort.v5.utils.plots import Annotator, colors, save_one_box
+ from yolort.v5.utils.torch_utils import copy_attr, time_sync


LOGGER = logging.getLogger(__name__)
@@ -414,32 +412,52 @@ def forward(self, x):


class AutoShape(nn.Module):
-     # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs.
-     # Includes preprocessing, inference and NMS
+     """
+     YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs.
+     Includes preprocessing, inference and NMS
+     """
+
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
-     classes = None  # (optional list) filter by class
+     # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
+     classes = None
    multi_label = False  # NMS multiple labels per box
    max_det = 1000  # maximum number of detections per image

    def __init__(self, model):
        super().__init__()
+         LOGGER.info("Adding AutoShape... ")
+         # copy attributes
+         copy_attr(self, model, include=("yaml", "nc", "hyp", "names", "stride", "abc"), exclude=())
        self.model = model.eval()

-     def autoshape(self):
-         LOGGER.info("AutoShape already enabled, skipping... ")  # model already converted to model.autoshape()
+     def _apply(self, fn):
+         """
+         Apply to(), cpu(), cuda(), half() to model tensors that
+         are not parameters or registered buffers
+         """
+         self = super()._apply(fn)
+         m = self.model.model[-1]  # Detect()
+         m.stride = fn(m.stride)
+         m.grid = list(map(fn, m.grid))
+         if isinstance(m.anchor_grid, list):
+             m.anchor_grid = list(map(fn, m.anchor_grid))
        return self
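
For context, a minimal usage sketch of the wrapper after this change. Hedged: `model` is an assumption standing in for any YOLOv5 detection nn.Module (e.g. one built by yolort's v5 helpers); it is not defined in this diff.

import torch

wrapped = AutoShape(model)  # preprocessing + inference + NMS in one call
wrapped.conf = 0.4          # NMS confidence threshold
wrapped.classes = [0]       # e.g. keep only COCO class 0 (person)

# The new _apply() override keeps the Detect head's stride/grid/anchor_grid
# in sync when the wrapper is moved or cast, since those are plain tensors
# rather than parameters or registered buffers.
if torch.cuda.is_available():
    wrapped = wrapped.cuda().half()

results = wrapped("data/images/zidane.jpg", size=640)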

    @torch.no_grad()
    def forward(self, imgs, size=640, augment=False, profile=False):
-         # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
-         #   file:     imgs = 'data/images/zidane.jpg'  # str or PosixPath
-         #   URI:           = 'https://ultralytics.com/images/zidane.jpg'
-         #   OpenCV:        = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
-         #   PIL:           = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
-         #   numpy:         = np.zeros((640,1280,3))  # HWC
-         #   torch:         = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
-         #   multiple:      = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+         """
+         Inference from various sources. For height=640, width=1280, RGB images example inputs are:
+         - file:     imgs = 'data/images/zidane.jpg'  # str or PosixPath
+         - URI:           = 'https://ultralytics.com/images/zidane.jpg'
+         - OpenCV:        = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+         - PIL:           = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
+         - numpy:         = np.zeros((640,1280,3))  # HWC
+         - torch:         = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
+         - multiple:      = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+         """
+         from yolort.v5.utils.augmentations import letterbox
+         from yolort.v5.utils.datasets import exif_transpose

        t = [time_sync()]
        p = next(self.model.parameters())  # for device and type
@@ -448,10 +466,10 @@ def forward(self, imgs, size=640, augment=False, profile=False):
            return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process
-         n, imgs = (
-             (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])
-         )  # number of images, list of images
-         shape0, shape1, files = [], [], []  # image and inference shapes, filenames
+         # number of images, list of images
+         n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])
+         # image and inference shapes, filenames
+         shape0, shape1, files = [], [], []
        for i, im in enumerate(imgs):
            f = f"image{i}"  # filename
            if isinstance(im, (str, Path)):  # filename or uri
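
Worth noting from the context lines above: a torch.Tensor input takes a fast path and is handed straight to the underlying model, skipping the letterbox preprocessing and the NMS below. A hedged sketch, continuing the `wrapped` example from earlier:

import numpy as np
import torch

batch = torch.zeros(16, 3, 320, 640)  # BCHW, already scaled to 0-1 values
raw = wrapped(batch)                  # raw inference output, no NMS applied

# Any other input (path, URI, PIL/numpy image, or a mixed list) goes through
# the pre-process block: letterbox, stack, BHWC to BCHW, divide by 255.
dets = wrapped(["data/images/zidane.jpg", np.zeros((640, 1280, 3))])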
@@ -476,7 +494,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
-         x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0  # uint8 to fp16/32
+         x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
        t.append(time_sync())

        with amp.autocast(enabled=p.device.type != "cpu"):
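
The array-manipulation lines above are easy to sanity-check in isolation; a small self-contained sketch (shapes illustrative only, not taken from the diff):

import numpy as np
import torch

im = np.zeros((480, 640, 3), dtype=np.uint8)         # one HWC uint8 image, as letterbox returns
x = im[None]                                         # the n == 1 case: x[0][None] -> BHWC (1, 480, 640, 3)
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW (1, 3, 480, 640)
x = torch.from_numpy(x).float() / 255                # uint8 to float in [0, 1]
assert x.shape == (1, 3, 480, 640) and float(x.max()) <= 1.0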
@@ -492,7 +510,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
            classes=self.classes,
            multi_label=self.multi_label,
            max_det=self.max_det,
-         )  # NMS
+         )
        for i in range(n):
            scale_coords(shape1, y[i][:, :4], shape0[i])
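
The hunk ends inside the post-processing loop: non_max_suppression returns one (n, 6) tensor per image with columns [x1, y1, x2, y2, confidence, class], and scale_coords maps each box from the letterboxed inference shape (shape1) back to the original image shape (shape0[i]). A hedged consumption sketch:

# Hypothetical: iterate the detections once the loop above has rescaled them.
for i, det in enumerate(y):                # one (n, 6) tensor per image
    for *xyxy, conf, cls in det.tolist():  # xyxy now in original-image pixels
        print(f"image {i}: class={int(cls)} conf={conf:.2f} box={xyxy}")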