Skip to content

Commit

Permalink
Add progress bar during download/extraction
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 229893839
  • Loading branch information
Conchylicultor authored and Copybara-Service committed Jan 18, 2019
1 parent 5b925ee commit b3e1438
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 46 deletions.
45 changes: 14 additions & 31 deletions tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,19 +260,20 @@ def callback(path):
return self._extract(resource)
return self._download(resource).then(callback)

def download(self, url_or_urls, async_=False):
def download(self, url_or_urls):
"""Download given url(s).
Args:
url_or_urls: url or `list`/`dict` of urls to download and extract. Each
url can be a `str` or `tfds.download.Resource`.
async_: `bool`, default to False. If True, returns promise on result.
Returns:
downloaded_path(s): `str`, The downloaded paths matching the given input
url_or_urls.
"""
return _map_promise(self._download, url_or_urls, async_=async_)
# Add progress bar to follow the download state
with self._downloader.tqdm():
return _map_promise(self._download, url_or_urls)

def iter_archive(self, resource):
"""Returns iterator over files within archive.
Expand All @@ -290,13 +291,12 @@ def iter_archive(self, resource):
resource = resource_lib.Resource(path=resource)
return extractor.iter_archive(resource.path, resource.extract_method)

def extract(self, path_or_paths, async_=False):
def extract(self, path_or_paths):
"""Extract given path(s).
Args:
path_or_paths: path or `list`/`dict` of path of file to extract. Each
path can be a `str` or `tfds.download.Resource`.
async_: `bool`, default to False. If True, returns promise on result.
If not explicitly specified in `Resource`, the extraction method is deduced
from downloaded file name.
Expand All @@ -305,9 +305,11 @@ def extract(self, path_or_paths, async_=False):
extracted_path(s): `str`, The extracted paths matching the given input
path_or_paths.
"""
return _map_promise(self._extract, path_or_paths, async_=async_)
# Add progress bar to follow the download state
with self._extractor.tqdm():
return _map_promise(self._extract, path_or_paths)

def download_and_extract(self, url_or_urls, async_=False):
def download_and_extract(self, url_or_urls):
"""Download and extract given url_or_urls.
Is roughly equivalent to:
Expand All @@ -319,15 +321,17 @@ def download_and_extract(self, url_or_urls, async_=False):
Args:
url_or_urls: url or `list`/`dict` of urls to download and extract. Each
url can be a `str` or `tfds.download.Resource`.
async_: `bool`, defaults to False. If True, returns promise on result.
If not explicitly specified in `Resource`, the extraction method will
automatically be deduced from downloaded file name.
Returns:
extracted_path(s): `str`, extracted paths of given URL(s).
"""
return _map_promise(self._download_extract, url_or_urls, async_=async_)
# Add progress bar to follow the download state
with self._downloader.tqdm():
with self._extractor.tqdm():
return _map_promise(self._download_extract, url_or_urls)

@property
def manual_dir(self):
Expand Down Expand Up @@ -357,31 +361,10 @@ def _wait_on_promise(p):
if p.is_fulfilled:
return result

class _KillablePromise(promise.Promise):

def __init__(self, promise_):
self._promise = promise_

def get(self, timeout=None):
if timeout is not None:
return self._promise.get(timeout)
return _wait_on_promise(self._promise)
# ============================================================================


def _map_promise(map_fn, all_inputs, async_):
def _map_promise(map_fn, all_inputs):
"""Map the function into each element and resolve the promise."""
all_promises = utils.map_nested(map_fn, all_inputs) # Apply the function
if async_:
# TODO(tfds): Fix for nested case
if isinstance(all_promises, dict):
merged_promise = promise.Promise.for_dict(all_promises)
elif isinstance(all_promises, list):
merged_promise = promise.Promise.all(all_promises)
else:
merged_promise = all_promises
if sys.version_info[0] == 2:
merged_promise = _KillablePromise(merged_promise)
return merged_promise

return utils.map_nested(_wait_on_promise, all_promises)
23 changes: 9 additions & 14 deletions tensorflow_datasets/core/download/download_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import json
import os
import re
import sys
import tempfile
import threading

Expand Down Expand Up @@ -144,11 +143,9 @@ def test_download(self):
downloaded_c, self.dl_results['https://a.ch/c'] = _get_promise_on_event(
('sha_c', 10))
manager = self._get_manager()
res = manager.download(urls, async_=True)
self.assertFalse(res.is_fulfilled)
downloaded_b.set()
downloaded_c.set()
downloads = res.get()
downloads = manager.download(urls)
self.assertEqual(downloads, {
'cached': '/dl_dir/%s' % afname,
'new': '/dl_dir/%s' % bfname,
Expand All @@ -171,10 +168,9 @@ def test_extract(self):
extracted_new, self.extract_results['/dl_dir/%s' % resource_new.fname] = (
_get_promise_on_event('/extract_dir/TAR.new'))
manager = self._get_manager()
res = manager.extract(files, async_=True)
self.assertFalse(res.is_fulfilled)
extracted_new.set()
self.assertEqual(res.get(), {
res = manager.extract(files)
self.assertEqual(res, {
'cached': '/extract_dir/ZIP.%s' % resource_cached.fname,
'new': '/extract_dir/TAR.%s' % resource_new.fname,
'noextract': '/dl_dir/%s' % resource_noextract.fname,
Expand All @@ -186,14 +182,13 @@ def test_extract_twice_parallel(self):
extracted_new, self.extract_results['/dl_dir/foo.tar'] = (
_get_promise_on_event('/extract_dir/TAR.foo'))
manager = self._get_manager()
res1 = manager.extract('/dl_dir/foo.tar', async_=True)
res2 = manager.extract('/dl_dir/foo.tar', async_=True)
if sys.version_info[0] > 2:
self.assertTrue(res1 is res2)
else:
self.assertTrue(res1._promise is res2._promise)
extracted_new.set()
self.assertEqual(res1.get(), '/extract_dir/TAR.foo')
out1 = manager.extract(['/dl_dir/foo.tar', '/dl_dir/foo.tar'])
out2 = manager.extract('/dl_dir/foo.tar')
self.assertEqual(out1[0], '/extract_dir/TAR.foo')
self.assertEqual(out1[1], '/extract_dir/TAR.foo')
self.assertEqual(out2, '/extract_dir/TAR.foo')
# Result is memoize so extract has only been called once
self.assertEqual(1, self.extractor_extract.call_count)

def test_download_and_extract(self):
Expand Down
29 changes: 29 additions & 0 deletions tensorflow_datasets/core/download/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import division
from __future__ import print_function

import contextlib
import hashlib
import io
import os
Expand All @@ -29,6 +30,7 @@
import requests
from tensorflow import gfile

from tensorflow_datasets.core import units
from tensorflow_datasets.core.download import util
from tensorflow_datasets.core.utils import py_utils

Expand Down Expand Up @@ -69,6 +71,18 @@ def __init__(self, max_simultaneous_downloads=50, checksumer=None):
self._executor = concurrent.futures.ThreadPoolExecutor(
max_workers=max_simultaneous_downloads)
self._checksumer = checksumer or hashlib.sha256
self._pbar_url = None
self._pbar_dl_size = None

@contextlib.contextmanager
def tqdm(self):
  """Context manager installing async progress bars for the downloads.

  While active, `_pbar_url` counts completed urls and `_pbar_dl_size`
  counts downloaded MiB; both are updated from downloader worker threads.
  """
  with py_utils.async_tqdm(
      total=0, desc='Dl Completed...', unit=' url') as pbar_url:
    with py_utils.async_tqdm(
        total=0, desc='Dl Size...', unit=' MiB') as pbar_dl_size:
      self._pbar_url = pbar_url
      self._pbar_dl_size = pbar_dl_size
      try:
        yield
      finally:
        # NOTE(review): bars are left assigned after exit; downloads started
        # outside this context would update closed bars — TODO confirm intended.
        pass

def download(self, url_info, destination_path):
"""Download url to given path. Returns Promise -> sha256 of downloaded file.
Expand All @@ -80,6 +94,7 @@ def download(self, url_info, destination_path):
Returns:
Promise obj -> (`str`, int): (downloaded object checksum, size in bytes).
"""
self._pbar_url.update_total(1)
url = url_info.url
future = self._executor.submit(self._sync_download, url, destination_path)
return promise.Promise.resolve(future)
Expand Down Expand Up @@ -109,11 +124,25 @@ def _sync_download(self, url, destination_path):
fname = _get_filename(response)
path = os.path.join(destination_path, fname)
size = 0

size_mb = 0
unit_mb = units.MiB
self._pbar_dl_size.update_total(
int(response.headers.get('Content-length', 0)) // unit_mb
)
with gfile.Open(path, 'wb') as file_:
for block in response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE):
size += len(block)

# Update the progress bar
size_mb += len(block)
if size_mb > unit_mb:
self._pbar_dl_size.update(size_mb // unit_mb)
size_mb %= unit_mb

checksum.update(block)
# TODO(pierrot): Test this is faster than doing checksum in the end
# and document results here.
file_.write(block)
self._pbar_url.update(1)
return checksum.hexdigest(), size
8 changes: 7 additions & 1 deletion tensorflow_datasets/core/download/downloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, url, content, cookies=None, headers=None, status_code=200):
self.url = url
self.raw = io.BytesIO(content)
self.cookies = cookies or {}
self.headers = headers or {}
self.headers = headers or {'Content-length': 12345}
self.status_code = status_code

def iter_content(self, chunk_size):
Expand All @@ -61,6 +61,12 @@ def setUp(self):
downloader.requests.Session, 'get',
lambda *a, **kw: _FakeResponse(self.url, self.response, self.cookies),
).start()
tf.test.mock.patch.object(
downloader.requests.Session, 'get',
lambda *a, **kw: _FakeResponse(self.url, self.response, self.cookies),
).start()
self.downloader._pbar_url = tf.test.mock.MagicMock()
self.downloader._pbar_dl_size = tf.test.mock.MagicMock()

def test_ok(self):
promise = self.downloader.download(self.resource, self.tmp_dir)
Expand Down
11 changes: 11 additions & 0 deletions tensorflow_datasets/core/download/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,19 @@ class _Extractor(object):
def __init__(self, max_workers=12):
self._executor = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers)
self._pbar_path = None

@contextlib.contextmanager
def tqdm(self):
  """Install an async progress bar counting completed extractions."""
  make_pbar = py_utils.async_tqdm
  with make_pbar(
      total=0, desc='Extraction completed...', unit=' file') as bar:
    self._pbar_path = bar
    yield

def extract(self, resource, to_path):
"""Returns `promise.Promise` => to_path."""
self._pbar_path.update_total(1)
if resource.extract_method not in _EXTRACT_METHODS:
raise ValueError('Unknonw extraction method "%s".' %
resource.extract_method)
Expand All @@ -86,6 +96,7 @@ def _sync_extract(self, resource, to_path):
if tf.gfile.Exists(to_path):
tf.gfile.DeleteRecursively(to_path)
tf.gfile.Rename(to_path_tmp, to_path)
self._pbar_path.update(1)
return to_path


Expand Down
1 change: 1 addition & 0 deletions tensorflow_datasets/core/download/extractor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def setUpClass(cls):
def setUp(self):
super(ExtractorTest, self).setUp()
self.extractor = extractor.get_extractor()
self.extractor._pbar_path = tf.test.mock.MagicMock()
# Where archive will be extracted:
self.to_path = os.path.join(self.tmp_dir, 'extracted_arch')
# Obviously it must not exist before test runs:
Expand Down
60 changes: 60 additions & 0 deletions tensorflow_datasets/core/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import tensorflow as tf
from tensorflow_datasets.core import constants

import tqdm


# pylint: disable=g-import-not-at-top
if sys.version_info[0] > 2:
Expand Down Expand Up @@ -270,3 +272,61 @@ def reraise(additional_msg):
exc_type, exc_value, exc_traceback = sys.exc_info()
msg = str(exc_value) + "\n" + additional_msg
six.reraise(exc_type, exc_type(msg), exc_traceback)


@contextlib.contextmanager
def async_tqdm(*args, **kwargs):
  """Wrapper around Tqdm which can be updated in threads.

  Usage:

  ```
  with utils.async_tqdm(...) as pbar:
    # pbar can then be modified inside a thread
    # pbar.update_total(3)
    # pbar.update()
  ```

  Args:
    *args: args of tqdm
    **kwargs: kwargs of tqdm

  Yields:
    pbar: Async pbar which can be shared between threads.
  """
  with tqdm.tqdm(*args, **kwargs) as pbar:
    pbar = _TqdmPbarAsync(pbar)
    try:
      yield pbar
    finally:
      # Run cleanup even if the caller's body raises, otherwise the bar
      # stays registered in _TqdmPbarAsync and keeps being refreshed.
      pbar.clear()  # pop pbar from the active list of pbar
      print()  # Avoid the next log to overlap with the bar


class _TqdmPbarAsync(object):
  """Thread-safe wrapper around a tqdm pbar, shareable between threads.

  Every active underlying bar is tracked in the class-level `_tqdm_bars`
  list so that any update refreshes all bars (keeps concurrently displayed
  bars from visually clobbering each other).
  """
  # Class-level registry of all currently active underlying tqdm bars.
  _tqdm_bars = []

  def __init__(self, pbar):
    # tqdm's global lock serializes updates coming from worker threads.
    self._lock = tqdm.tqdm.get_lock()
    self._pbar = pbar
    self._tqdm_bars.append(pbar)

  def update_total(self, n=1):
    """Increment total pbar value."""
    with self._lock:
      self._pbar.total += n
      self.refresh()

  def update(self, n=1):
    """Increment current value."""
    with self._lock:
      self._pbar.update(n)
      self.refresh()

  def refresh(self):
    """Refresh all active pbars."""
    for pbar in self._tqdm_bars:
      pbar.refresh()

  def clear(self):
    """Remove this pbar from the active list.

    Uses `remove` rather than `pop` so the correct bar is dropped even
    when bars are not closed in strict LIFO order.
    """
    self._tqdm_bars.remove(self._pbar)

0 comments on commit b3e1438

Please sign in to comment.