Skip to content

Commit

Permalink
[Storage] Add progress callback to download_blob methods (Azure#24276)
Browse files Browse the repository at this point in the history
  • Loading branch information
jalauzon-msft authored May 4, 2022
1 parent c59eb64 commit 6a83ffa
Show file tree
Hide file tree
Showing 12 changed files with 1,517 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,11 @@ def download_blob(self, offset=None, length=None, **kwargs):
The number of parallel connections with which to download.
:keyword str encoding:
Encoding to decode the downloaded bytes. Default is None, i.e. no decoding.
:keyword progress_hook:
A callback to track the progress of a long running download. The signature is
function(current: int, total: int) where current is the number of bytes transferred
so far, and total is the total size of the download.
:paramtype progress_hook: Callable[[int, int], None]
:keyword int timeout:
The timeout parameter is expressed in seconds. This method may make
multiple calls to the Azure service and the timeout will apply to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,11 @@ def download_blob(self, blob, offset=None, length=None, **kwargs):
The number of parallel connections with which to download.
:keyword str encoding:
Encoding to decode the downloaded bytes. Default is None, i.e. no decoding.
:keyword progress_hook:
A callback to track the progress of a long running download. The signature is
function(current: int, total: int) where current is the number of bytes transferred
so far, and total is the total size of the download.
:paramtype progress_hook: Callable[[int, int], None]
:keyword int timeout:
The timeout parameter is expressed in seconds. This method may make
multiple calls to the Azure service and the timeout will apply to
Expand Down
13 changes: 11 additions & 2 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import warnings
from io import BytesIO
from typing import Iterator
from typing import Iterator, Union

import requests
from azure.core.exceptions import HttpResponseError, ServiceResponseError
Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(
parallel=None,
validate_content=None,
encryption_options=None,
progress_hook=None,
**kwargs
):
self.client = client
Expand All @@ -96,6 +97,7 @@ def __init__(
self.stream = stream
self.stream_lock = threading.Lock() if parallel else None
self.progress_lock = threading.Lock() if parallel else None
self.progress_hook = progress_hook

# For a parallel download, the stream is always seekable, so we note down the current position
# in order to seek to the right place when out-of-order chunks come in
Expand Down Expand Up @@ -143,6 +145,9 @@ def _update_progress(self, length):
else:
self.progress_total += length

if self.progress_hook:
self.progress_hook(self.progress_total, self.total_size)

def _write_to_stream(self, chunk_data, chunk_start):
if self.stream_lock:
with self.stream_lock: # pylint: disable=not-context-manager
Expand Down Expand Up @@ -322,6 +327,7 @@ def __init__(
self._encoding = encoding
self._validate_content = validate_content
self._encryption_options = encryption_options or {}
self._progress_hook = kwargs.pop('progress_hook', None)
self._request_options = kwargs
self._location_mode = None
self._download_complete = False
Expand Down Expand Up @@ -514,7 +520,6 @@ def readall(self):
"""Download the contents of this blob.
This operation is blocking until all data is downloaded.
:rtype: bytes or str
"""
stream = BytesIO()
Expand Down Expand Up @@ -583,6 +588,9 @@ def readinto(self, stream):

# Write the content to the user stream
stream.write(self._current_content)
if self._progress_hook:
self._progress_hook(len(self._current_content), self.size)

if self._download_complete:
return self.size

Expand All @@ -604,6 +612,7 @@ def readinto(self, stream):
validate_content=self._validate_content,
encryption_options=self._encryption_options,
use_location=self._location_mode,
progress_hook=self._progress_hook,
**self._request_options
)
if parallel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,11 @@ async def download_blob(self, offset=None, length=None, **kwargs):
The number of parallel connections with which to download.
:keyword str encoding:
Encoding to decode the downloaded bytes. Default is None, i.e. no decoding.
:keyword progress_hook:
An async callback to track the progress of a long running download. The signature is
function(current: int, total: int) where current is the number of bytes transferred
so far, and total is the total size of the download.
:paramtype progress_hook: Callable[[int, int], Awaitable[None]]
:keyword int timeout:
The timeout parameter is expressed in seconds. This method may make
multiple calls to the Azure service and the timeout will apply to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,11 @@ async def download_blob(self, blob, offset=None, length=None, **kwargs):
The number of parallel connections with which to download.
:keyword str encoding:
Encoding to decode the downloaded bytes. Default is None, i.e. no decoding.
:keyword progress_hook:
An async callback to track the progress of a long running download. The signature is
function(current: int, total: int) where current is the number of bytes transferred
so far, and total is the total size of the download.
:paramtype progress_hook: Callable[[int, int], Awaitable[None]]
:keyword int timeout:
The timeout parameter is expressed in seconds. This method may make
multiple calls to the Azure service and the timeout will apply to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ async def _update_progress(self, length):
else:
self.progress_total += length

if self.progress_hook:
await self.progress_hook(self.progress_total, self.total_size)

async def _write_to_stream(self, chunk_data, chunk_start):
if self.stream_lock:
async with self.stream_lock: # pylint: disable=not-async-context-manager
Expand Down Expand Up @@ -220,6 +223,7 @@ def __init__(
self._encoding = encoding
self._validate_content = validate_content
self._encryption_options = encryption_options or {}
self._progress_hook = kwargs.pop('progress_hook', None)
self._request_options = kwargs
self._location_mode = None
self._download_complete = False
Expand Down Expand Up @@ -472,6 +476,9 @@ async def readinto(self, stream):

# Write the content to the user stream
stream.write(self._current_content)
if self._progress_hook:
await self._progress_hook(len(self._current_content), self.size)

if self._download_complete:
return self.size

Expand All @@ -493,6 +500,7 @@ async def readinto(self, stream):
validate_content=self._validate_content,
encryption_options=self._encryption_options,
use_location=self._location_mode,
progress_hook=self._progress_hook,
**self._request_options)

dl_tasks = downloader.get_chunk_offsets()
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

108 changes: 99 additions & 9 deletions sdk/storage/azure-storage-blob/tests/test_get_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,24 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import pytest
import base64
import unittest
import pytest
import uuid
from os import path, remove, sys, urandom
from io import BytesIO
from os import path, remove
from azure.core.exceptions import HttpResponseError
from devtools_testutils import ResourceGroupPreparer, StorageAccountPreparer

from azure.storage.blob import (
BlobServiceClient,
ContainerClient,
BlobClient,
StorageErrorCode,
BlobProperties
)
from settings.testcase import BlobPreparer
from devtools_testutils.storage import StorageTestCase
from settings.testcase import BlobPreparer
from test_helpers import ProgressTracker

# ------------------------------------------------------------------------------
TEST_BLOB_PREFIX = 'blob'


# ------------------------------------------------------------------------------

class StorageGetBlobTest(StorageTestCase):
Expand Down Expand Up @@ -928,5 +924,99 @@ def test_get_blob_range_with_range_md5(self, storage_account_name, storage_accou
self.assertIsNotNone(content.properties.content_settings.content_type)
self.assertIsNone(content.properties.content_settings.content_md5)

@BlobPreparer()
def test_get_blob_progress_single_get(self, storage_account_name, storage_account_key):
    """Progress hook fires once, covering the whole payload, for a single-GET download."""
    # Arrange: blob small enough to be fetched in one request.
    self._setup(storage_account_name, storage_account_key)
    content = b'a' * 512
    client = self.bsc.get_blob_client(self.container_name, self._get_blob_reference())
    client.upload_blob(content, overwrite=True)

    # Expect one callback reporting the full size in a single step.
    tracker = ProgressTracker(len(content), len(content))

    # Act
    client.download_blob(progress_hook=tracker.assert_progress).readall()

    # Assert
    tracker.assert_complete()

@BlobPreparer()
def test_get_blob_progress_chunked(self, storage_account_name, storage_account_key):
    """Progress hook advances in chunk-sized steps for a sequential chunked download."""
    # Arrange: blob larger than a single chunk so the download is split up.
    self._setup(storage_account_name, storage_account_key)
    content = b'a' * 5120
    client = self.bsc.get_blob_client(self.container_name, self._get_blob_reference())
    client.upload_blob(content, overwrite=True)

    # Sequential transfer -> progress reported in 1024-byte increments.
    tracker = ProgressTracker(len(content), 1024)

    # Act
    client.download_blob(max_concurrency=1, progress_hook=tracker.assert_progress).readall()

    # Assert
    tracker.assert_complete()

@pytest.mark.live_test_only
@BlobPreparer()
def test_get_blob_progress_chunked_parallel(self, storage_account_name, storage_account_key):
    """Progress hook reaches completion even when chunks arrive out of order."""
    # Parallel downloads issue requests in a nondeterministic order, so the
    # recorded-playback harness cannot replay them; live service only.
    self._setup(storage_account_name, storage_account_key)
    content = b'a' * 5120
    client = self.bsc.get_blob_client(self.container_name, self._get_blob_reference())
    client.upload_blob(content, overwrite=True)

    tracker = ProgressTracker(len(content), 1024)

    # Act
    client.download_blob(max_concurrency=3, progress_hook=tracker.assert_progress).readall()

    # Assert
    tracker.assert_complete()

@pytest.mark.live_test_only
@BlobPreparer()
def test_get_blob_progress_range(self, storage_account_name, storage_account_key):
    """Progress hook totals reflect the requested range, not the whole blob."""
    # Parallel downloads issue requests in a nondeterministic order, so the
    # recorded-playback harness cannot replay them; live service only.
    self._setup(storage_account_name, storage_account_key)
    content = b'a' * 5120
    client = self.bsc.get_blob_client(self.container_name, self._get_blob_reference())
    client.upload_blob(content, overwrite=True)

    download_length = 4096
    tracker = ProgressTracker(download_length, 1024)

    # Act: pull a 4096-byte slice starting at offset 512.
    client.download_blob(
        offset=512,
        length=download_length,
        max_concurrency=3,
        progress_hook=tracker.assert_progress
    ).readall()

    # Assert
    tracker.assert_complete()

@pytest.mark.live_test_only
@BlobPreparer()
def test_get_blob_progress_readinto(self, storage_account_name, storage_account_key):
    """Progress hook is honored by the readinto() path as well as readall()."""
    # Parallel downloads issue requests in a nondeterministic order, so the
    # recorded-playback harness cannot replay them; live service only.
    self._setup(storage_account_name, storage_account_key)
    content = b'a' * 5120
    client = self.bsc.get_blob_client(self.container_name, self._get_blob_reference())
    client.upload_blob(content, overwrite=True)

    tracker = ProgressTracker(len(content), 1024)
    destination = BytesIO()

    # Act
    downloader = client.download_blob(max_concurrency=3, progress_hook=tracker.assert_progress)
    bytes_read = downloader.readinto(destination)

    # Assert
    tracker.assert_complete()
    self.assertEqual(len(content), bytes_read)

# ------------------------------------------------------------------------------
Loading

0 comments on commit 6a83ffa

Please sign in to comment.