Skip to content

Commit

Permalink
internal.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 230335295
  • Loading branch information
pierrot0 authored and Copybara-Service committed Jan 22, 2019
1 parent 7c638f9 commit 949414c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
3 changes: 2 additions & 1 deletion tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,5 @@ def _wait_on_promise(p):
def _map_promise(map_fn, all_inputs):
"""Map the function into each element and resolve the promise."""
all_promises = utils.map_nested(map_fn, all_inputs) # Apply the function
return utils.map_nested(_wait_on_promise, all_promises)
res = utils.map_nested(_wait_on_promise, all_promises)
return res
31 changes: 20 additions & 11 deletions tensorflow_datasets/core/download/download_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import promise
import tensorflow as tf
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.download import download_manager as dm
from tensorflow_datasets.core.download import resource as resource_lib

Expand Down Expand Up @@ -150,11 +151,12 @@ def test_download(self):
downloaded_b.set()
downloaded_c.set()
downloads = manager.download(urls)
self.assertEqual(downloads, {
expected = {
'cached': '/dl_dir/%s' % afname,
'new': '/dl_dir/%s' % bfname,
'info_deleted': '/dl_dir/%s' % cfname,
})
}
self.assertEqual(downloads, expected)

def test_extract(self):
"""One file already extracted, one file with NO_EXTRACT, one to extract."""
Expand All @@ -174,11 +176,12 @@ def test_extract(self):
manager = self._get_manager()
extracted_new.set()
res = manager.extract(files)
self.assertEqual(res, {
expected = {
'cached': '/extract_dir/ZIP.%s' % resource_cached.fname,
'new': '/extract_dir/TAR.%s' % resource_new.fname,
'noextract': '/dl_dir/%s' % resource_noextract.fname,
})
}
self.assertEqual(res, expected)

def test_extract_twice_parallel(self):
# Make sure calling extract twice on same resource actually does the
Expand All @@ -189,9 +192,11 @@ def test_extract_twice_parallel(self):
extracted_new.set()
out1 = manager.extract(['/dl_dir/foo.tar', '/dl_dir/foo.tar'])
out2 = manager.extract('/dl_dir/foo.tar')
self.assertEqual(out1[0], '/extract_dir/TAR.foo')
self.assertEqual(out1[1], '/extract_dir/TAR.foo')
self.assertEqual(out2, '/extract_dir/TAR.foo')
expected = '/extract_dir/TAR.foo'
self.assertEqual(out1[0], expected)
self.assertEqual(out1[1], expected)
expected = '/extract_dir/TAR.foo'
self.assertEqual(out2, expected)
# Result is memoize so extract has only been called once
self.assertEqual(1, self.extractor_extract.call_count)

Expand All @@ -216,9 +221,11 @@ def test_download_and_extract(self):
manager._checksums[url_a] = sha_contenta
manager._checksums[url_b] = sha_contentb
res = manager.download_and_extract({'a': url_a, 'b': url_b})
self.assertEqual(res, {
expected = {
'a': '/extract_dir/ZIP.%s' % resource_a.fname,
'b': '/dl_dir/%s' % resource_b.fname})
'b': '/dl_dir/%s' % resource_b.fname,
}
self.assertEqual(res, expected)

def test_download_and_extract_already_downloaded(self):
url_a = 'http://a/a.zip'
Expand All @@ -233,7 +240,8 @@ def test_download_and_extract_already_downloaded(self):
ext_a.set()
manager = self._get_manager()
res = manager.download_and_extract(url_a)
self.assertEqual(res, '/extract_dir/ZIP.%s' % resource_a.fname)
expected = '/extract_dir/ZIP.%s' % resource_a.fname
self.assertEqual(res, expected)

def test_force_download_and_extract(self):
url = 'http://a/b.tar.gz'
Expand All @@ -253,7 +261,8 @@ def test_force_download_and_extract(self):
manager = self._get_manager(force_download=True, force_extraction=True,
checksums={url: resource_.sha256})
res = manager.download_and_extract(url)
self.assertEqual('/extract_dir/TAR_GZ.%s' % resource_.fname, res)
expected = '/extract_dir/TAR_GZ.%s' % resource_.fname
self.assertEqual(expected, res)
# Rename after download:
(from_, to), kwargs = self.gfile.rename.call_args
self.assertTrue(re.match(
Expand Down

0 comments on commit 949414c

Please sign in to comment.