Skip to content

Commit da9a1b7

Browse files
authored
Allow --weights URL (#5991)
1 parent b7d18f3 commit da9a1b7

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

models/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
296296
check_suffix(w, suffixes) # check weights have acceptable suffix
297297
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
298298
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
299-
attempt_download(w) # download if not local
299+
w = attempt_download(w) # download if not local
300300

301301
if jit: # TorchScript
302302
LOGGER.info(f'Loading {w} for TorchScript inference...')
@@ -306,7 +306,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
306306
d = json.loads(extra_files['config.txt']) # extra_files dict
307307
stride, names = int(d['stride']), d['names']
308308
elif pt: # PyTorch
309-
model = attempt_load(weights, map_location=device)
309+
model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
310310
stride = int(model.stride.max()) # model stride
311311
names = model.module.names if hasattr(model, 'module') else model.names # get class names
312312
self.model = model # explicitly assign for to(), cpu(), cuda(), half()

utils/downloads.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@ def attempt_download(file, repo='ultralytics/yolov5'): # from utils.downloads i
4949
name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
5050
if str(file).startswith(('http:/', 'https:/')): # download
5151
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
52-
name = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
53-
safe_download(file=name, url=url, min_bytes=1E5)
54-
return name
52+
file = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
53+
if Path(file).is_file():
54+
print(f'Found {url} locally at {file}') # file already exists
55+
else:
56+
safe_download(file=file, url=url, min_bytes=1E5)
57+
return file
5558

5659
# GitHub assets
5760
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)

0 commit comments

Comments
 (0)