Skip to content

Commit

Permalink
Types in poller (Azure#29228)
Browse files Browse the repository at this point in the history
* Types in poller

* Black it right

* Saving this week's work

* Move typing

* Split base polling in two

* Typing fixes

* Typing update

* Black

* Unecessary Generic

* Stringify types

* Fix import

* Spellcheck

* Weird typo...

* PyLint

* More types

* Update sdk/core/azure-core/azure/core/polling/async_base_polling.py

Co-authored-by: Kashif Khan <361477+kashifkhan@users.noreply.github.com>

* Missing type

* Typing of the day

* Re-enable verifytypes

* Simplify the expectations async pipeline has on the response

* Async Cxt Manager

* Final Typing?

* More covariant

* Upside down

* Fix tests

* Messed up merge

* Pylint

* Better Typing

* Final typing?

* Pylint

* Simplify translation typing for now

* Fix backcompat with azure-mgmt-core

* Revert renaming private methods

* Black

* Feedback from @kristapratico

* Docstrings part 1

* Polling pylint part 2

* Black

* All LRO impl should use TypeVar

* Feedback

* Convert some Anyu after feedback

* Spellcheck

* Black

* Update sdk/core/azure-core/azure/core/polling/_async_poller.py

* Update sdk/core/azure-core/azure/core/polling/_async_poller.py

* Update sdk/core/azure-core/azure/core/polling/_poller.py

---------

Co-authored-by: Kashif Khan <361477+kashifkhan@users.noreply.github.com>
  • Loading branch information
lmazuel and kashifkhan authored Jul 6, 2023
1 parent 715008b commit 8f06d2c
Show file tree
Hide file tree
Showing 15 changed files with 642 additions and 279 deletions.
1 change: 1 addition & 0 deletions .vscode/cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
"ints",
"iohttp",
"IOHTTP",
"IOLRO",
"inprogress",
"ipconfiguration",
"ipconfigurations",
Expand Down
22 changes: 3 additions & 19 deletions sdk/core/azure-core/azure/core/_pipeline_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@
Generic,
Optional,
cast,
TYPE_CHECKING,
)
from typing_extensions import Protocol
from .configuration import Configuration
from .pipeline import AsyncPipeline
from .pipeline.transport._base import PipelineClientBase
Expand All @@ -51,17 +49,8 @@
)


if TYPE_CHECKING: # Protocol and non-Protocol can't mix in Python 3.7

class _AsyncContextManagerCloseable(AsyncContextManager, Protocol):
"""Defines a context manager that is closeable at the same time."""

async def close(self):
...


HTTPRequestType = TypeVar("HTTPRequestType")
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="_AsyncContextManagerCloseable")
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="AsyncContextManager")

_LOGGER = logging.getLogger(__name__)

Expand All @@ -80,11 +69,9 @@ class _Coroutine(Awaitable[AsyncHTTPResponseType]):
This allows the dev to either use the "async with" syntax, or simply the object directly.
It's also why "send_request" is not declared as async, since it couldn't be both easily.
"wrapped" must be an awaitable that returns an object that:
- has an async "close()"
- has an "__aexit__" method (IOW, is an async context manager)
"wrapped" must be an awaitable object that returns an object implements the async context manager protocol.
This permits this code to work for both requests.
This permits this code to work for both following requests.
```python
from azure.core import AsyncPipelineClient
Expand Down Expand Up @@ -124,9 +111,6 @@ async def __aenter__(self) -> AsyncHTTPResponseType:
async def __aexit__(self, *args) -> None:
await self._response.__aexit__(*args)

async def close(self) -> None:
await self._response.close()


class AsyncPipelineClient(
PipelineClientBase,
Expand Down
4 changes: 2 additions & 2 deletions sdk/core/azure-core/azure/core/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

from typing import TypeVar, Generic, Dict, Any

HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")
HTTPResponseType = TypeVar("HTTPResponseType", covariant=True)
HTTPRequestType = TypeVar("HTTPRequestType", covariant=True)


class PipelineContext(Dict[str, Any]):
Expand Down
37 changes: 8 additions & 29 deletions sdk/core/azure-core/azure/core/pipeline/policies/_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,46 +35,25 @@
import types
import re
import uuid
from typing import IO, cast, Union, Optional, AnyStr, Dict, MutableMapping, Any, Set, Mapping
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Set, Mapping
import urllib.parse
from typing_extensions import Protocol

from azure.core import __version__ as azcore_version
from azure.core.exceptions import DecodeError

from azure.core.pipeline import PipelineRequest, PipelineResponse
from ._base import SansIOHTTPPolicy

from ..transport import HttpRequest as LegacyHttpRequest
from ..transport._base import _HttpResponseBase as LegacySansIOHttpResponse
from ...rest import HttpRequest
from ...rest._rest_py3 import _HttpResponseBase as SansIOHttpResponse

_LOGGER = logging.getLogger(__name__)


class HTTPRequestType(Protocol):
"""Protocol compatible with new rest request and legacy transport request"""

headers: MutableMapping[str, str]
url: str
method: str
body: bytes


class HTTPResponseType(Protocol):
"""Protocol compatible with new rest response and legacy transport response"""

@property
def headers(self) -> MutableMapping[str, str]:
...

@property
def status_code(self) -> int:
...

@property
def content_type(self) -> Optional[str]:
...

def text(self, encoding: Optional[str] = None) -> str:
...
HTTPRequestType = Union[LegacyHttpRequest, HttpRequest]
HTTPResponseType = Union[LegacySansIOHttpResponse, SansIOHttpResponse]
PipelineResponseType = PipelineResponse[HTTPRequestType, HTTPResponseType]


class HeadersPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand Down
14 changes: 10 additions & 4 deletions sdk/core/azure-core/azure/core/pipeline/policies/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,34 @@
# --------------------------------------------------------------------------
import datetime
import email.utils
from typing import Optional, cast

from urllib.parse import urlparse
from ...utils._utils import _FixedOffset, case_insensitive_dict


def _parse_http_date(text):
def _parse_http_date(text: str) -> datetime.datetime:
"""Parse a HTTP date format into datetime.
:param str text: Text containing a date in HTTP format
:rtype: datetime.datetime
:return: The parsed datetime
"""
parsed_date = email.utils.parsedate_tz(text)
return datetime.datetime(*parsed_date[:6], tzinfo=_FixedOffset(parsed_date[9] / 60))
if not parsed_date:
raise ValueError("Invalid HTTP date")
tz_offset = cast(int, parsed_date[9]) # Look at the code, tz_offset is always an int, at worst 0
return datetime.datetime(*parsed_date[:6], tzinfo=_FixedOffset(tz_offset / 60))


def parse_retry_after(retry_after: str):
def parse_retry_after(retry_after: str) -> float:
"""Helper to parse Retry-After and get value in seconds.
:param str retry_after: Retry-After header
:rtype: float
:return: Value of Retry-After in seconds.
"""
delay: float # Using the Mypy recommendation to use float for "int or float"
try:
delay = int(retry_after)
except ValueError:
Expand All @@ -56,7 +62,7 @@ def parse_retry_after(retry_after: str):
return max(0, delay)


def get_retry_after(response):
def get_retry_after(response) -> Optional[float]:
"""Get the value of Retry-After in seconds.
:param response: The PipelineResponse object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _iterate_response_content(iterator):
raise _ResponseStopIteration() # pylint: disable=raise-missing-from


class AsyncHttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
class AsyncHttpResponse(_HttpResponseBase, AbstractAsyncContextManager): # pylint: disable=abstract-method
"""An AsyncHttpResponse ABC.
Allows for the asynchronous streaming of data from the response.
Expand Down Expand Up @@ -93,6 +93,9 @@ def parts(self) -> AsyncIterator:

return _PartGenerator(self, default_http_response_type=AsyncHttpClientTransportResponse)

async def __aexit__(self, exc_type, exc_value, traceback):
return None


class AsyncHttpClientTransportResponse( # pylint: disable=abstract-method
_HttpClientTransportResponse, AsyncHttpResponse
Expand Down
18 changes: 15 additions & 3 deletions sdk/core/azure-core/azure/core/polling/_async_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,12 @@ async def run(self): # pylint:disable=invalid-overridden-method
"""


async def async_poller(client, initial_response, deserialization_callback, polling_method):
async def async_poller(
client: Any,
initial_response: Any,
deserialization_callback: Callable[[Any], PollingReturnType_co],
polling_method: AsyncPollingMethod[PollingReturnType_co],
) -> PollingReturnType_co:
"""Async Poller for long running operations.
.. deprecated:: 1.5.0
Expand All @@ -86,6 +91,8 @@ async def async_poller(client, initial_response, deserialization_callback, polli
:type deserialization_callback: callable or msrest.serialization.Model
:param polling_method: The polling strategy to adopt
:type polling_method: ~azure.core.polling.PollingMethod
:return: The final resource at the end of the polling.
:rtype: any or None
"""
poller = AsyncLROPoller(client, initial_response, deserialization_callback, polling_method)
return await poller
Expand All @@ -109,7 +116,7 @@ def __init__(
self,
client: Any,
initial_response: Any,
deserialization_callback: Callable,
deserialization_callback: Callable[[Any], PollingReturnType_co],
polling_method: AsyncPollingMethod[PollingReturnType_co],
):
self._polling_method = polling_method
Expand All @@ -124,7 +131,11 @@ def __init__(
self._polling_method.initialize(client, initial_response, deserialization_callback)

def polling_method(self) -> AsyncPollingMethod[PollingReturnType_co]:
"""Return the polling method associated to this poller."""
"""Return the polling method associated to this poller.
:return: The polling method associated to this poller.
:rtype: ~azure.core.polling.AsyncPollingMethod
"""
return self._polling_method

def continuation_token(self) -> str:
Expand Down Expand Up @@ -158,6 +169,7 @@ async def result(self) -> PollingReturnType_co:
"""Return the result of the long running operation.
:returns: The deserialized resource of the long running operation, if one is available.
:rtype: any or None
:raises ~azure.core.exceptions.HttpResponseError: Server problem with the query.
"""
await self.wait()
Expand Down
46 changes: 29 additions & 17 deletions sdk/core/azure-core/azure/core/polling/_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,35 +64,44 @@ def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any
raise TypeError("Polling method '{}' doesn't support from_continuation_token".format(cls.__name__))


class NoPolling(PollingMethod):
class NoPolling(PollingMethod[PollingReturnType_co]):
"""An empty poller that returns the deserialized initial response."""

_deserialization_callback: Callable[[Any], PollingReturnType_co]
"""Deserialization callback passed during initialization"""

def __init__(self):
self._initial_response = None
self._deserialization_callback = None

def initialize(self, _: Any, initial_response: Any, deserialization_callback: Callable) -> None:
def initialize(
self,
_: Any,
initial_response: Any,
deserialization_callback: Callable[[Any], PollingReturnType_co],
) -> None:
self._initial_response = initial_response
self._deserialization_callback = deserialization_callback

def run(self) -> None:
"""Empty run, no polling."""

def status(self) -> str:
"""Return the current status as a string.
"""Return the current status.
:rtype: str
:return: The current status
"""
return "succeeded"

def finished(self) -> bool:
"""Is this polling finished?
:rtype: bool
:return: Whether this polling is finished
"""
return True

def resource(self) -> Any:
def resource(self) -> PollingReturnType_co:
return self._deserialization_callback(self._initial_response)

def get_continuation_token(self) -> str:
Expand All @@ -105,7 +114,7 @@ def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any
try:
deserialization_callback = kwargs["deserialization_callback"]
except KeyError:
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token")
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from None
import pickle

initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec
Expand All @@ -130,7 +139,7 @@ def __init__(
self,
client: Any,
initial_response: Any,
deserialization_callback: Callable,
deserialization_callback: Callable[[Any], PollingReturnType_co],
polling_method: PollingMethod[PollingReturnType_co],
) -> None:
self._callbacks: List[Callable] = []
Expand All @@ -147,10 +156,11 @@ def __init__(

# Prepare thread execution
self._thread = None
self._done = None
self._done = threading.Event()
self._exception = None
if not self._polling_method.finished():
self._done = threading.Event()
if self._polling_method.finished():
self._done.set()
else:
self._thread = threading.Thread(
target=with_current_context(self._start),
name="LROPoller({})".format(uuid.uuid4()),
Expand All @@ -161,9 +171,6 @@ def __init__(
def _start(self):
"""Start the long running operation.
On completion, runs any callbacks.
:param callable update_cmd: The API request to check the status of
the operation.
"""
try:
self._polling_method.run()
Expand All @@ -189,7 +196,11 @@ def _start(self):
callbacks, self._callbacks = self._callbacks, []

def polling_method(self) -> PollingMethod[PollingReturnType_co]:
"""Return the polling method associated to this poller."""
"""Return the polling method associated to this poller.
:return: The polling method
:rtype: ~azure.core.polling.PollingMethod
"""
return self._polling_method

def continuation_token(self) -> str:
Expand Down Expand Up @@ -223,8 +234,9 @@ def result(self, timeout: Optional[float] = None) -> PollingReturnType_co:
"""Return the result of the long running operation, or
the result available after the specified timeout.
:returns: The deserialized resource of the long running operation,
if one is available.
:param float timeout: Period of time to wait before getting back control.
:returns: The deserialized resource of the long running operation, if one is available.
:rtype: any or None
:raises ~azure.core.exceptions.HttpResponseError: Server problem with the query.
"""
self.wait(timeout)
Expand Down Expand Up @@ -266,7 +278,7 @@ def add_done_callback(self, func: Callable) -> None:
argument, a completed LongRunningOperation.
"""
# Still use "_done" and not "done", since CBs are executed inside the thread.
if self._done is None or self._done.is_set():
if self._done.is_set():
func(self._polling_method)
# Let's add them still, for consistency (if you wish to access to it for some reasons)
self._callbacks.append(func)
Expand Down
Loading

0 comments on commit 8f06d2c

Please sign in to comment.