Skip to content

Commit 52d1943

Browse files
authored
Update TorchScript suffix to *.torchscript (ultralytics#5856)
1 parent 1b46244 commit 52d1943

File tree

5 files changed

+20
-20
lines changed

5 files changed

+20
-20
lines changed

detect.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,18 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
8181
imgsz = check_img_size(imgsz, s=stride) # check image size
8282

8383
# Half
84-
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
85-
if pt:
84+
half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
85+
if pt or jit:
8686
model.model.half() if half else model.model.float()
8787

8888
# Dataloader
8989
if webcam:
9090
view_img = check_imshow()
9191
cudnn.benchmark = True # set True to speed up constant image size inference
92-
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
92+
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
9393
bs = len(dataset) # batch_size
9494
else:
95-
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt and not jit)
95+
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
9696
bs = 1 # batch_size
9797
vid_path, vid_writer = [None] * bs, [None] * bs
9898

export.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Format | Example | Export `include=(...)` argument
66
--- | --- | ---
77
PyTorch | yolov5s.pt | -
8-
TorchScript | yolov5s.torchscript.pt | 'torchscript'
8+
TorchScript | yolov5s.torchscript | 'torchscript'
99
ONNX | yolov5s.onnx | 'onnx'
1010
CoreML | yolov5s.mlmodel | 'coreml'
1111
TensorFlow SavedModel | yolov5s_saved_model/ | 'saved_model'
@@ -19,7 +19,7 @@
1919
2020
Inference:
2121
$ python path/to/detect.py --weights yolov5s.pt
22-
yolov5s.torchscript.pt
22+
yolov5s.torchscript
2323
yolov5s.onnx
2424
yolov5s.mlmodel (under development)
2525
yolov5s_saved_model
@@ -66,7 +66,7 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'
6666
# YOLOv5 TorchScript model export
6767
try:
6868
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
69-
f = file.with_suffix('.torchscript.pt')
69+
f = file.with_suffix('.torchscript')
7070

7171
ts = torch.jit.trace(model, im, strict=False)
7272
d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}

models/common.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ class DetectMultiBackend(nn.Module):
279279
def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
280280
# Usage:
281281
# PyTorch: weights = *.pt
282-
# TorchScript: *.torchscript.pt
282+
# TorchScript: *.torchscript
283283
# CoreML: *.mlmodel
284284
# TensorFlow: *_saved_model
285285
# TensorFlow: *.pb
@@ -289,10 +289,10 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
289289
# TensorRT: *.engine
290290
super().__init__()
291291
w = str(weights[0] if isinstance(weights, list) else weights)
292-
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
292+
suffix = Path(w).suffix.lower()
293+
suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
293294
check_suffix(w, suffixes) # check weights have acceptable suffix
294-
pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
295-
jit = pt and 'torchscript' in w.lower()
295+
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
296296
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
297297

298298
if jit: # TorchScript
@@ -304,10 +304,10 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
304304
stride, names = int(d['stride']), d['names']
305305
elif pt: # PyTorch
306306
from models.experimental import attempt_load # scoped to avoid circular import
307-
model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
307+
model = attempt_load(weights, map_location=device)
308308
stride = int(model.stride.max()) # model stride
309309
names = model.module.names if hasattr(model, 'module') else model.names # get class names
310-
elif coreml: # CoreML *.mlmodel
310+
elif coreml: # CoreML
311311
import coremltools as ct
312312
model = ct.models.MLModel(w)
313313
elif dnn: # ONNX OpenCV DNN

utils/activations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def forward(x):
1818
class Hardswish(nn.Module): # export-friendly version of nn.Hardswish()
1919
@staticmethod
2020
def forward(x):
21-
# return x * F.hardsigmoid(x) # for torchscript and CoreML
22-
return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0 # for torchscript, CoreML and ONNX
21+
# return x * F.hardsigmoid(x) # for TorchScript and CoreML
22+
return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0 # for TorchScript, CoreML and ONNX
2323

2424

2525
# Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------

val.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def run(data,
111111
# Initialize/load model and set device
112112
training = model is not None
113113
if training: # called by train.py
114-
device, pt, engine = next(model.parameters()).device, True, False # get model device, PyTorch model
114+
device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model
115115

116116
half &= device.type != 'cpu' # half precision only supported on CUDA
117117
model.half() if half else model.float()
@@ -124,10 +124,10 @@ def run(data,
124124

125125
# Load model
126126
model = DetectMultiBackend(weights, device=device, dnn=dnn)
127-
stride, pt, engine = model.stride, model.pt, model.engine
127+
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
128128
imgsz = check_img_size(imgsz, s=stride) # check image size
129-
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
130-
if pt:
129+
half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
130+
if pt or jit:
131131
model.model.half() if half else model.model.float()
132132
elif engine:
133133
batch_size = model.batch_size
@@ -166,7 +166,7 @@ def run(data,
166166
pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
167167
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
168168
t1 = time_sync()
169-
if pt or engine:
169+
if pt or jit or engine:
170170
im = im.to(device, non_blocking=True)
171171
targets = targets.to(device)
172172
im = im.half() if half else im.float() # uint8 to fp16/32

0 commit comments

Comments
 (0)