Skip to content

Commit

Permalink
Merge master
Browse files Browse the repository at this point in the history
  • Loading branch information
zldrobit committed Dec 10, 2021
2 parents 177f01e + 4fb6dd4 commit 411d7cf
Show file tree
Hide file tree
Showing 18 changed files with 138 additions and 89 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ data/samples/*
**/*.pt
**/*.pth
**/*.onnx
**/*.engine
**/*.mlmodel
**/*.torchscript
**/*.torchscript.pt
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)"
- name: Cache pip
uses: actions/cache@v2.1.6
uses: actions/cache@v2.1.7
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }}
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ VOC/
*.pt
*.pb
*.onnx
*.engine
*.mlmodel
*.torchscript
*.tflite
Expand Down
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,6 @@ ADD https://ultralytics.com/assets/Arial.ttf /root/.config/Ultralytics/

# DDP test
# python -m torch.distributed.run --nproc_per_node 2 --master_port 1 train.py --epochs 3

# GCP VM from Image
# docker.io/ultralytics/yolov5:latest
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ $ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size
* [Roboflow for Datasets, Labeling, and Active Learning](https://github.com/ultralytics/yolov5/issues/4975)  🌟 NEW
* [Multi-GPU Training](https://github.com/ultralytics/yolov5/issues/475)
* [PyTorch Hub](https://github.com/ultralytics/yolov5/issues/36)  ⭐ NEW
* [TorchScript, ONNX, CoreML Export](https://github.com/ultralytics/yolov5/issues/251) 🚀
* [TFLite, ONNX, CoreML, TensorRT Export](https://github.com/ultralytics/yolov5/issues/251) 🚀
* [Test-Time Augmentation (TTA)](https://github.com/ultralytics/yolov5/issues/303)
* [Model Ensembling](https://github.com/ultralytics/yolov5/issues/318)
* [Model Pruning/Sparsity](https://github.com/ultralytics/yolov5/issues/304)
Expand Down
15 changes: 6 additions & 9 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
project=ROOT / 'runs/detect', # save results to project/name
name='exp', # save results to project/name
exist_ok=False, # existing project/name ok, do not increment
cfg='./models/yolov5s.yaml', # *.yaml config file
line_thickness=3, # bounding box thickness (pixels)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
Expand All @@ -78,29 +77,28 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn, cfg=cfg, data=data)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size

# Half
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
if pt:
half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
if pt or jit:
model.model.half() if half else model.model.float()

# Dataloader
if webcam:
view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt and not jit)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
model.warmup(imgsz=(1, 3, *imgsz), half=half) # warmup
dt, seen = [0.0, 0.0, 0.0], 0
for path, im, im0s, vid_cap, s in dataset:
t1 = time_sync()
Expand Down Expand Up @@ -226,7 +224,6 @@ def parse_opt():
parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--cfg', type=str, default='./models/yolov5s.yaml', help='cfg path')
parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
Expand Down
24 changes: 12 additions & 12 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
"""
Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
Format | Example | `--include ...` argument
--- | --- | ---
PyTorch | yolov5s.pt | -
TorchScript | yolov5s.torchscript | `torchscript`
ONNX | yolov5s.onnx | `onnx`
CoreML | yolov5s.mlmodel | `coreml`
TensorFlow SavedModel | yolov5s_saved_model/ | `saved_model`
TensorFlow GraphDef | yolov5s.pb | `pb`
TensorFlow Lite | yolov5s.tflite | `tflite`
TensorFlow Edge TPU | yolov5s-int8_edgetpu.tflite | `edgetpu`
TensorFlow.js | yolov5s_web_model/ | `tfjs`
TensorRT | yolov5s.engine | `engine'
Format | Example | `--include ...` argument
--- | --- | ---
PyTorch | yolov5s.pt | -
TorchScript | yolov5s.torchscript | `torchscript`
ONNX | yolov5s.onnx | `onnx`
CoreML | yolov5s.mlmodel | `coreml`
TensorFlow SavedModel | yolov5s_saved_model/ | `saved_model`
TensorFlow GraphDef | yolov5s.pb | `pb`
TensorFlow Lite | yolov5s.tflite | `tflite`
TensorFlow Edge TPU | yolov5s-int8_edgetpu.tflite | `edgetpu`
TensorFlow.js | yolov5s_web_model/ | `tfjs`
TensorRT | yolov5s.engine | `engine'
Usage:
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite edgetpu tfjs
Expand Down
14 changes: 7 additions & 7 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Usage:
import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
model = torch.hub.load('ultralytics/yolov5:master', 'custom', 'path/to/yolov5s.onnx') # file from branch
"""

import torch
Expand All @@ -27,26 +28,25 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
"""
from pathlib import Path

from models.common import AutoShape
from models.experimental import attempt_load
from models.common import AutoShape, DetectMultiBackend
from models.yolo import Model
from utils.downloads import attempt_download
from utils.general import check_requirements, intersect_dicts, set_logging
from utils.torch_utils import select_device

file = Path(__file__).resolve()
check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
set_logging(verbose=verbose)

save_dir = Path('') if str(name).endswith('.pt') else file.parent
path = (save_dir / name).with_suffix('.pt') # checkpoint path
name = Path(name)
path = name.with_suffix('.pt') if name.suffix == '' else name # checkpoint path
try:
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)

if pretrained and channels == 3 and classes == 80:
model = attempt_load(path, map_location=device) # download/load FP32 model
model = DetectMultiBackend(path, device=device) # download/load FP32 model
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
ckpt = torch.load(attempt_download(path), map_location=device) # load
Expand Down
77 changes: 48 additions & 29 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import math
import platform
import warnings
from collections import namedtuple
from collections import OrderedDict, namedtuple
from copy import copy
from pathlib import Path

Expand Down Expand Up @@ -277,10 +277,10 @@ def forward(self, x):

class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=None, dnn=True, cfg=None, data=None):
def __init__(self, weights='yolov5s.pt', device=None, dnn=True, data=None):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript.pt
# TorchScript: *.torchscript
# CoreML: *.mlmodel
# TensorFlow: *_saved_model
# TensorFlow: *.pb
Expand All @@ -289,13 +289,16 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True, cfg=None, data=N
# ONNX Runtime: *.onnx
# OpenCV DNN: *.onnx with dnn=True
# TensorRT: *.engine
from models.experimental import attempt_download, attempt_load # scoped to avoid circular import

super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
suffix = Path(w).suffix.lower()
suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
jit = pt and 'torchscript' in w.lower()
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
attempt_download(w) # download if not local

if jit: # TorchScript
LOGGER.info(f'Loading {w} for TorchScript inference...')
Expand All @@ -305,11 +308,12 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True, cfg=None, data=N
d = json.loads(extra_files['config.txt']) # extra_files dict
stride, names = int(d['stride']), d['names']
elif pt: # PyTorch
from models.experimental import attempt_load # scoped to avoid circular import
model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
model = attempt_load(weights, map_location=device)
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
elif coreml: # CoreML *.mlmodel
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
elif coreml: # CoreML
LOGGER.info(f'Loading {w} for CoreML inference...')
import coremltools as ct
model = ct.models.MLModel(w)
elif dnn: # ONNX OpenCV DNN
Expand All @@ -318,24 +322,26 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True, cfg=None, data=N
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
cuda = torch.cuda.is_available()
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
session = onnxruntime.InferenceSession(w, providers=providers)
elif engine: # TensorRT
LOGGER.info(f'Loading {w} for TensorRT inference...')
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(f.read())
bindings = dict()
bindings = OrderedDict()
for index in range(model.num_bindings):
name = model.get_binding_name(index)
dtype = trt.nptype(model.get_binding_dtype(index))
shape = tuple(model.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
binding_addrs = {n: d.ptr for n, d in bindings.items()}
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model)
Expand Down Expand Up @@ -386,7 +392,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
if self.pt: # PyTorch
y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
return y if val else y[0]
elif self.coreml: # CoreML *.mlmodel
elif self.coreml: # CoreML
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
im = Image.fromarray((im[0] * 255).astype('uint8'))
# im = im.resize((192, 320), Image.ANTIALIAS)
Expand Down Expand Up @@ -432,29 +438,41 @@ def forward(self, im, augment=False, visualize=False, val=False):
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y

def warmup(self, imgsz=(1, 3, 640, 640), half=False):
# Warmup model by running inference once
if self.pt or self.engine or self.onnx: # warmup types
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
self.forward(im) # warmup


class AutoShape(nn.Module):
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
conf = 0.25 # NMS confidence threshold
iou = 0.45 # NMS IoU threshold
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
agnostic = False # NMS class-agnostic
multi_label = False # NMS multiple labels per box
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
max_det = 1000 # maximum number of detections per image
amp = False # Automatic Mixed Precision (AMP) inference

def __init__(self, model):
super().__init__()
LOGGER.info('Adding AutoShape... ')
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
self.pt = not self.dmb or model.pt # PyTorch model
self.model = model.eval()

def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
m = self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
if self.pt:
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self

@torch.no_grad()
Expand All @@ -469,9 +487,10 @@ def forward(self, imgs, size=640, augment=False, profile=False):
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images

t = [time_sync()]
p = next(self.model.parameters()) # for device and type
p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
if isinstance(imgs, torch.Tensor): # torch
with amp.autocast(enabled=p.device.type != 'cpu'):
with amp.autocast(enabled=autocast):
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference

# Pre-process
Expand All @@ -493,21 +512,21 @@ def forward(self, imgs, size=640, augment=False, profile=False):
g = (size / max(s)) # gain
shape1.append([y * g for y in s])
imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
shape1 = [make_divisible(x, self.stride) for x in np.stack(shape1, 0).max(0)] # inference shape
x = [letterbox(im, new_shape=shape1 if self.pt else size, auto=False)[0] for im in imgs] # pad
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
t.append(time_sync())

with amp.autocast(enabled=p.device.type != 'cpu'):
with amp.autocast(enabled=autocast):
# Inference
y = self.model(x, augment, profile)[0] # forward
y = self.model(x, augment, profile) # forward
t.append(time_sync())

# Post-process
y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes,
multi_label=self.multi_label, max_det=self.max_det) # NMS
y = non_max_suppression(y if self.dmb else y[0], self.conf, iou_thres=self.iou, classes=self.classes,
agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])

Expand Down Expand Up @@ -604,7 +623,7 @@ def pandas(self):

def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():'
x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
x = [Detections([self.imgs[i]], [self.pred[i]], names=self.names, shape=self.s) for i in range(self.n)]
for d in x:
for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
setattr(d, k, getattr(d, k)[0]) # pop out of list
Expand Down
Loading

0 comments on commit 411d7cf

Please sign in to comment.