Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add wget option in download #33379

Merged
merged 6 commits into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/paddle/hapi/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
url = _git_archive_link(repo_owner, repo_name, branch, source=source)

fpath = get_path_from_url(
url, hub_dir, check_exist=not force_reload, decompress=False)
url,
hub_dir,
check_exist=not force_reload,
decompress=False,
method=('wget' if source == 'gitee' else 'get'))
shutil.move(fpath, cached_file)

with zipfile.ZipFile(cached_file) as cached_zipfile:
Expand Down
25 changes: 25 additions & 0 deletions python/paddle/tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,31 @@ def test_retry_exception(self, ):
'www.baidu.com',
'./test', )

def test_wget_download_error(self, ):
    # An unresolvable host must surface as RuntimeError when the
    # `wget` backend is selected.
    from paddle.utils.download import _download
    with self.assertRaises(RuntimeError):
        _download('www.baidu', './test', method='wget')

def test_download_methods(self, ):
    # Exercise every supported backend against both archive formats.
    import sys
    from paddle.utils.download import _download

    urls = [
        "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
        "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
    ]

    # `wget` is only assumed to be installed on linux hosts.
    methods = ['wget', 'get'] if sys.platform == 'linux' else ['get']

    for url in urls:
        for method in methods:
            _download(url, path='./test', method=method)


# Run the download test suite when executed as a script.
if __name__ == '__main__':
    unittest.main()
109 changes: 76 additions & 33 deletions python/paddle/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os.path as osp
import shutil
import requests
import subprocess
import hashlib
import tarfile
import zipfile
Expand Down Expand Up @@ -121,7 +122,8 @@ def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True):
decompress=True,
method='get'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comments for decompress and method. Add candidates for method

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
Expand All @@ -132,7 +134,9 @@ def get_path_from_url(url,
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package

decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.

Returns:
str: a local path to save downloaded models & weights & datasets.
"""
Expand All @@ -150,7 +154,7 @@ def get_path_from_url(url,
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum)
fullpath = _download(url, root_dir, md5sum, method=method)
else:
while not os.path.exists(fullpath):
time.sleep(1)
Expand All @@ -163,59 +167,98 @@ def get_path_from_url(url,
return fullpath


def _download(url, path, md5sum=None):
def _get_download(url, fullname):
    """Download ``url`` to ``fullname`` with ``requests.get``.

    Args:
        url (str): download url.
        fullname (str): destination file path.

    Returns:
        str|bool: ``fullname`` on success, ``False`` on a connection
            failure (so the caller's retry loop can back off and retry).

    Raises:
        RuntimeError: when the server answers with a non-200 status code.
    """
    fname = osp.basename(fullname)
    try:
        req = requests.get(url, stream=True)
    except Exception as e:  # requests.exceptions.ConnectionError
        logger.info("Downloading {} from {} failed with exception {}".format(
            fname, url, str(e)))
        return False

    # Ensure the streamed response is always closed, including the
    # non-200 error path — otherwise the underlying connection leaks.
    try:
        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # For protecting download interrupted, download to
        # tmp_fullname firstly, move tmp_fullname to fullname
        # after download finished
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        # skip empty keep-alive chunks
                        if chunk:
                            f.write(chunk)
                            pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)
    finally:
        req.close()

    return fullname


def _wget_download(url, fullname):
    """Download ``url`` to ``fullname`` using the external ``wget`` binary.

    Args:
        url (str): download url.
        fullname (str): destination file path.

    Returns:
        str: ``fullname`` on success.

    Raises:
        RuntimeError: when the ``wget`` process exits non-zero (binary
            missing, url unreachable, retries exhausted).
    """
    # Download to a temp file first so an interrupted transfer never
    # leaves a partial file at ``fullname``.
    tmp_fullname = fullname + "_tmp"
    # Pass an argument list with shell=False so special characters in
    # the url (e.g. '&', ';', spaces) are never interpreted by a shell
    # (avoids command injection and broken downloads).
    command = [
        'wget', '-O', tmp_fullname, '-t', str(DOWNLOAD_RETRY_LIMIT), url
    ]
    subprc = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    _ = subprc.communicate()

    if subprc.returncode != 0:
        raise RuntimeError(
            '{} failed. Please make sure `wget` is installed or {} exists'.
            format(' '.join(command), url))

    shutil.move(tmp_fullname, fullname)

    return fullname


# Registry mapping the public ``method`` string accepted by
# ``_download``/``get_path_from_url`` to its implementation.
_download_methods = {
    'get': _get_download,
    'wget': _wget_download,
}


def _download(url, path, md5sum=None, method='get'):
    """
    Download from url, save to path.

    Args:
        url (str): download url.
        path (str): download to given path.
        md5sum (str): md5 sum of download package. ``None`` skips the check.
        method (str): which download method to use. Support `wget` and
            `get`. Default is `get`.

    Returns:
        str: local path of the downloaded file.

    Raises:
        RuntimeError: when the retry limit is reached without a
            successful, checksum-valid download.
    """
    assert method in _download_methods, 'make sure `{}` implemented'.format(
        method)

    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    logger.info("Downloading {} from {}".format(fname, url))
    # Loop until the file exists and (when md5sum is given) verifies.
    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        # A falsy return signals a recoverable failure (e.g. connection
        # error); back off briefly before the next attempt.
        if not _download_methods[method](url, fullname):
            time.sleep(1)
            continue

    return fullname


Expand Down