Skip to content

Commit 8414e75

Browse files
authored
Refactor new model.warmup() method (ultralytics#5810)
* Refactor new `model.warmup()` method * Add half
1 parent ca9ad37 commit 8414e75

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

detect.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
9797
vid_path, vid_writer = [None] * bs, [None] * bs
9898

9999
# Run inference
100-
if pt and device.type != 'cpu':
101-
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
100+
model.warmup(imgsz=(1, 3, *imgsz), half=half) # warmup
102101
dt, seen = [0.0, 0.0, 0.0], 0
103102
for path, im, im0s, vid_cap, s in dataset:
104103
t1 = time_sync()

models/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,13 @@ def forward(self, im, augment=False, visualize=False, val=False):
421421
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
422422
return (y, []) if val else y
423423

424+
def warmup(self, imgsz=(1, 3, 640, 640), half=False):
425+
# Warmup model by running inference once
426+
if self.pt or self.engine or self.onnx: # warmup types
427+
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
428+
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
429+
self.forward(im) # warmup
430+
424431

425432
class AutoShape(nn.Module):
426433
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS

val.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ def run(data,
149149

150150
# Dataloader
151151
if not training:
152-
if pt and device.type != 'cpu':
153-
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
152+
model.warmup(imgsz=(1, 3, imgsz, imgsz), half=half) # warmup
154153
pad = 0.0 if task == 'speed' else 0.5
155154
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
156155
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,

0 commit comments

Comments
 (0)