Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure-Core Exceptions: Type Complete #31056

Merged
merged 10 commits into from
Aug 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 107 additions & 42 deletions sdk/core/azure-core/azure/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,39 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------

from __future__ import annotations
import json
import logging
import sys

from typing import Callable, Any, Optional, Union, Type, List, Dict, TYPE_CHECKING
from types import TracebackType
from typing import (
Callable,
Any,
Optional,
Union,
Type,
List,
Mapping,
TypeVar,
Generic,
Dict,
TYPE_CHECKING,
)
from typing_extensions import Protocol, runtime_checkable

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from azure.core.pipeline.transport._base import _HttpResponseBase
from azure.core.pipeline.policies import RequestHistory

HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")
KeyType = TypeVar("KeyType")
lmazuel marked this conversation as resolved.
Show resolved Hide resolved
ValueType = TypeVar("ValueType")
lmazuel marked this conversation as resolved.
Show resolved Hide resolved
# To replace when typing.Self is available in our baseline
SelfODataV4Format = TypeVar("SelfODataV4Format", bound="ODataV4Format")


__all__ = [
"AzureError",
Expand All @@ -59,7 +79,7 @@
]


def raise_with_traceback(exception: Callable, *args, **kwargs) -> None:
def raise_with_traceback(exception: Callable, *args: Any, **kwargs: Any) -> None:
"""Raise exception with a specified traceback.
This MUST be called inside a "except" clause.

Expand All @@ -83,26 +103,58 @@ def raise_with_traceback(exception: Callable, *args, **kwargs) -> None:
raise error # pylint: disable=raise-missing-from


class ErrorMap:
@runtime_checkable
class _HttpResponseCommonAPI(Protocol):
"""Protocol used by exceptions for HTTP response.

As HttpResponseError uses very few properties of HttpResponse, a protocol
is faster and simpler than import all the possible types (at least 6).
"""

@property
def reason(self) -> Optional[str]:
pass

@property
def status_code(self) -> Optional[int]:
annatisch marked this conversation as resolved.
Show resolved Hide resolved
pass

def text(self) -> str:
pass

@property
def request(self) -> object: # object as type, since all we need is str() on it
pass


class ErrorMap(Generic[KeyType, ValueType]):
"""Error Map class. To be used in map_error method, behaves like a dictionary.
It returns the error type if it is found in custom_error_map. Or return default_error

:param dict custom_error_map: User-defined error map, it is used to map status codes to error types.
:keyword error default_error: Default error type. It is returned if the status code is not found in custom_error_map
"""

def __init__(self, custom_error_map=None, **kwargs):
def __init__(
self, # pylint: disable=unused-argument
custom_error_map: Optional[Mapping[KeyType, ValueType]] = None,
*,
default_error: Optional[ValueType] = None,
**kwargs: Any,
) -> None:
self._custom_error_map = custom_error_map or {}
self._default_error = kwargs.pop("default_error", None)
self._default_error = default_error

def get(self, key):
def get(self, key: KeyType) -> Optional[ValueType]:
ret = self._custom_error_map.get(key)
if ret:
return ret
return self._default_error


def map_error(status_code, response, error_map):
def map_error(
status_code: int, response: _HttpResponseCommonAPI, error_map: Mapping[int, Type[HttpResponseError]]
) -> None:
if not error_map:
return
error_type = error_map.get(status_code)
Expand Down Expand Up @@ -157,7 +209,7 @@ class ODataV4Format:
DETAILS_LABEL = "details"
INNERERROR_LABEL = "innererror"

def __init__(self, json_object: Dict[str, Any]):
def __init__(self, json_object: Mapping[str, Any]) -> None:
if "error" in json_object:
json_object = json_object["error"]
cls: Type[ODataV4Format] = self.__class__
Expand All @@ -180,10 +232,10 @@ def __init__(self, json_object: Dict[str, Any]):
except Exception: # pylint: disable=broad-except
pass

self.innererror: Dict[str, Any] = json_object.get(cls.INNERERROR_LABEL, {})
self.innererror: Mapping[str, Any] = json_object.get(cls.INNERERROR_LABEL, {})

@property
def error(self):
def error(self: SelfODataV4Format) -> SelfODataV4Format:
import warnings

warnings.warn(
Expand All @@ -192,7 +244,7 @@ def error(self):
)
return self

def __str__(self):
def __str__(self) -> str:
return "({}) {}\n{}".format(self.code, self.message, self.message_details())

def message_details(self) -> str:
Expand Down Expand Up @@ -220,7 +272,7 @@ def message_details(self) -> str:
class AzureError(Exception):
"""Base exception for all errors.

:param message: The message object stringified as 'message' attribute
:param object message: The message object stringified as 'message' attribute
:keyword error: The original exception if any
:paramtype error: Exception

Expand All @@ -235,16 +287,21 @@ class AzureError(Exception):
and will be `None` where continuation is either unavailable or not applicable.
"""

def __init__(self, message, *args, **kwargs):
self.inner_exception = kwargs.get("error")
self.exc_type, self.exc_value, self.exc_traceback = sys.exc_info()
self.exc_type = self.exc_type.__name__ if self.exc_type else type(self.inner_exception)
self.exc_msg = "{}, {}: {}".format(message, self.exc_type, self.exc_value)
self.message = str(message)
self.continuation_token = kwargs.get("continuation_token")
def __init__(self, message: Optional[object], *args: Any, **kwargs: Any) -> None:
annatisch marked this conversation as resolved.
Show resolved Hide resolved
self.inner_exception: Optional[BaseException] = kwargs.get("error")

exc_info = sys.exc_info()
self.exc_type: Optional[Type[Any]] = exc_info[0]
self.exc_value: Optional[BaseException] = exc_info[1]
self.exc_traceback: Optional[TracebackType] = exc_info[2]

self.exc_type = self.exc_type if self.exc_type else type(self.inner_exception)
self.exc_msg: str = "{}, {}: {}".format(message, self.exc_type.__name__, self.exc_value)
self.message: Optional[str] = str(message)
self.continuation_token: Optional[str] = kwargs.get("continuation_token")
super(AzureError, self).__init__(self.message, *args)

def raise_with_traceback(self):
def raise_with_traceback(self) -> None:
"""Raise the exception with the existing traceback.

.. deprecated:: 1.22.0
Expand All @@ -253,7 +310,7 @@ def raise_with_traceback(self):
try:
raise super(AzureError, self).with_traceback(self.exc_traceback) # pylint: disable=raise-missing-from
except AttributeError:
self.__traceback__ = self.exc_traceback
self.__traceback__: Optional[TracebackType] = self.exc_traceback
raise self # pylint: disable=raise-missing-from


Expand All @@ -280,8 +337,7 @@ class ServiceResponseTimeoutError(ServiceResponseError):
class HttpResponseError(AzureError):
"""A request was made, and a non-success status code was received from the service.

:param message: HttpResponse's error message
:type message: string
:param object message: The message object stringified as 'message' attribute
:param response: The response that triggered the exception.
:type response: ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse

Expand All @@ -297,24 +353,27 @@ class HttpResponseError(AzureError):
:vartype error: ODataV4Format
"""

def __init__(self, message=None, response=None, **kwargs):
def __init__(
self, message: Optional[object] = None, response: Optional[_HttpResponseCommonAPI] = None, **kwargs: Any
) -> None:
# Don't want to document this one yet.
error_format = kwargs.get("error_format", ODataV4Format)

self.reason = None
self.status_code = None
self.response = response
self.reason: Optional[str] = None
self.status_code: Optional[int] = None
self.response: Optional[_HttpResponseCommonAPI] = response
if response:
self.reason = response.reason
self.status_code = response.status_code

# old autorest are setting "error" before calling __init__, so it might be there already
# transferring into self.model
model: Optional[Any] = kwargs.pop("model", None)
self.model: Optional[Any]
if model is not None: # autorest v5
self.model = model
else: # autorest azure-core, for KV 1.0, Storage 12.0, etc.
self.model: Optional[Any] = getattr(self, "error", None)
self.model = getattr(self, "error", None)
self.error: Optional[ODataV4Format] = self._parse_odata_body(error_format, response)

# By priority, message is:
Expand All @@ -329,19 +388,23 @@ def __init__(self, message=None, response=None, **kwargs):
super(HttpResponseError, self).__init__(message=message, **kwargs)

@staticmethod
def _parse_odata_body(error_format: Type[ODataV4Format], response: "_HttpResponseBase") -> Optional[ODataV4Format]:
def _parse_odata_body(
error_format: Type[ODataV4Format], response: Optional[_HttpResponseCommonAPI]
) -> Optional[ODataV4Format]:
try:
odata_json = json.loads(response.text())
# https://github.com/python/mypy/issues/14743#issuecomment-1664725053
odata_json = json.loads(response.text()) # type: ignore
return error_format(odata_json)
except Exception: # pylint: disable=broad-except
# If the body is not JSON valid, just stop now
pass
return None

def __str__(self):
def __str__(self) -> str:
retval = super(HttpResponseError, self).__str__()
try:
body = self.response.text()
# https://github.com/python/mypy/issues/14743#issuecomment-1664725053
body = self.response.text() # type: ignore
if body and not self.error:
return "{}\nContent: {}".format(retval, body)[:2048]
except Exception: # pylint: disable=broad-except
Expand Down Expand Up @@ -381,14 +444,16 @@ class ResourceNotModifiedError(HttpResponseError):
This will not be raised directly by the Azure core pipeline."""


class TooManyRedirectsError(HttpResponseError):
class TooManyRedirectsError(HttpResponseError, Generic[HTTPRequestType, HTTPResponseType]):
"""Reached the maximum number of redirect attempts.

:param history: The history of requests made while trying to fulfill the request.
:type history: list[~azure.core.pipeline.policies.RequestHistory]
"""

def __init__(self, history, *args, **kwargs):
def __init__(
self, history: "List[RequestHistory[HTTPRequestType, HTTPResponseType]]", *args: Any, **kwargs: Any
) -> None:
self.history = history
message = "Reached maximum redirect attempts."
super(TooManyRedirectsError, self).__init__(message, *args, **kwargs)
Expand All @@ -414,7 +479,7 @@ class ODataV4Error(HttpResponseError):

_ERROR_FORMAT = ODataV4Format

def __init__(self, response: "_HttpResponseBase", **kwargs) -> None:
def __init__(self, response: _HttpResponseCommonAPI, **kwargs: Any) -> None:
# Ensure field are declared, whatever can happen afterwards
self.odata_json: Optional[Dict[str, Any]] = None
try:
Expand All @@ -428,7 +493,7 @@ def __init__(self, response: "_HttpResponseBase", **kwargs) -> None:
self.message: Optional[str] = kwargs.get("message", odata_message)
self.target: Optional[str] = None
self.details: Optional[List[Any]] = []
self.innererror: Optional[Dict[str, Any]] = {}
self.innererror: Optional[Mapping[str, Any]] = {}

if self.message and "message" not in kwargs:
kwargs["message"] = self.message
Expand All @@ -445,7 +510,7 @@ def __init__(self, response: "_HttpResponseBase", **kwargs) -> None:
_LOGGER.info("Received error message was not valid OdataV4 format.")
self._error_format = "JSON was invalid for format " + str(self._ERROR_FORMAT)

def __str__(self):
def __str__(self) -> str:
if self._error_format:
return str(self._error_format)
return super(ODataV4Error, self).__str__()
Expand All @@ -461,7 +526,7 @@ class StreamConsumedError(AzureError):
:type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
"""

def __init__(self, response):
def __init__(self, response: _HttpResponseCommonAPI) -> None:
message = (
"You are attempting to read or stream the content from request {}. "
"You have likely already consumed this stream, so it can not be accessed anymore.".format(response.request)
Expand All @@ -479,7 +544,7 @@ class StreamClosedError(AzureError):
:type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
"""

def __init__(self, response):
def __init__(self, response: _HttpResponseCommonAPI) -> None:
message = (
"The content for response from request {} can no longer be read or streamed, since the "
"response has already been closed.".format(response.request)
Expand All @@ -497,7 +562,7 @@ class ResponseNotReadError(AzureError):
:type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
"""

def __init__(self, response):
def __init__(self, response: _HttpResponseCommonAPI) -> None:
message = (
"You have not read in the bytes for the response from request {}. "
"Call .read() on the response first.".format(response.request)
Expand Down