diff --git a/storage/google/cloud/storage/blob.py b/storage/google/cloud/storage/blob.py index d2784d6e9ad6..3bf5228b997a 100644 --- a/storage/google/cloud/storage/blob.py +++ b/storage/google/cloud/storage/blob.py @@ -478,8 +478,13 @@ def download_to_filename(self, filename, client=None): :raises: :class:`google.cloud.exceptions.NotFound` """ - with open(filename, 'wb') as file_obj: - self.download_to_file(file_obj, client=client) + try: + with open(filename, 'wb') as file_obj: + self.download_to_file(file_obj, client=client) + except resumable_media.DataCorruption as exc: + # Delete the corrupt downloaded file. + os.remove(filename) + raise updated = self.updated if updated is not None: diff --git a/storage/setup.py b/storage/setup.py index 3d0ea59632d3..7aa64bafe114 100644 --- a/storage/setup.py +++ b/storage/setup.py @@ -53,7 +53,7 @@ REQUIREMENTS = [ 'google-cloud-core >= 0.27.0, < 0.28dev', 'google-auth >= 1.0.0', - 'google-resumable-media >= 0.2.3', + 'google-resumable-media >= 0.3.0', 'requests >= 2.18.0', ] diff --git a/storage/tests/unit/test_blob.py b/storage/tests/unit/test_blob.py index e0a41ee793d2..8f0f2ad80132 100644 --- a/storage/tests/unit/test_blob.py +++ b/storage/tests/unit/test_blob.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import datetime +import hashlib import io import json import os +import tempfile import unittest import mock @@ -607,6 +610,78 @@ def test_download_to_filename(self): def test_download_to_filename_wo_updated(self): self._download_to_filename_helper() + def test_download_to_filename_corrupted(self): + from google.resumable_media import DataCorruption + from google.resumable_media.requests.download import _CHECKSUM_MISMATCH + + blob_name = 'blob-name' + transport = mock.Mock(spec=['request']) + empty_hash = base64.b64encode( + hashlib.md5(b'').digest()).decode(u'utf-8') + headers = {'x-goog-hash': 'md5=' + empty_hash} + response = mock.MagicMock( + headers=headers, + status_code=http_client.OK, + spec=[ + '__enter__', + '__exit__', + 'headers', + 'iter_content', + 'status_code', + ], + ) + # i.e. context manager returns ``self``. + response.__enter__.return_value = response + response.__exit__.return_value = None + chunks = (b'noms1', b'coooookies2') + response.iter_content.return_value = iter(chunks) + + transport.request.return_value = response + # Create a fake client/bucket and use them in the Blob() constructor. + client = mock.Mock(_http=transport, spec=['_http']) + bucket = mock.Mock( + client=client, + user_project=None, + spec=['client', 'user_project'], + ) + media_link = 'http://example.com/media/' + properties = {'mediaLink': media_link} + blob = self._make_one(blob_name, bucket=bucket, properties=properties) + # Make sure the download is **not** chunked. + self.assertIsNone(blob.chunk_size) + + # Make sure the hash will be wrong. + content = b''.join(chunks) + expected_hash = base64.b64encode( + hashlib.md5(content).digest()).decode(u'utf-8') + self.assertNotEqual(empty_hash, expected_hash) + + # Try to download into a temporary file (don't use + # `_NamedTemporaryFile` it will try to remove after the file is + # already removed) + filehandle, filename = tempfile.mkstemp() + os.close(filehandle) + with self.assertRaises(DataCorruption) as exc_info: + blob.download_to_filename(filename) + + msg = _CHECKSUM_MISMATCH.format(media_link, empty_hash, expected_hash) + self.assertEqual(exc_info.exception.args, (msg,)) + # Make sure the file was cleaned up. + self.assertFalse(os.path.exists(filename)) + + # Check the mocks. + response.__enter__.assert_called_once_with() + response.__exit__.assert_called_once_with(None, None, None) + response.iter_content.assert_called_once_with( + chunk_size=8192, decode_unicode=False) + transport.request.assert_called_once_with( + 'GET', + media_link, + data=None, + headers={'accept-encoding': 'gzip'}, + stream=True, + ) + def test_download_to_filename_w_key(self): import os import time