Skip to content

Commit

Permalink
Improve request ID for custom header name (#30771)
Browse files Browse the repository at this point in the history
* Add header parameter to request ID

* Tests

* Add RequestIdPolicy to config

* Add also to async

* ChangeLog

* Issue numer

* Black

* Linter

* Black and typing

* Black tests

* Some simple typing to help verifytypes score

* Typing

* More typing for verifytypes

* Adding SansIOPolicy

* Keep it simple in typing for now

* Timeout typing

* Typing again

* Revert any changes that touches polling

* Disabling verifytypes score until we talked about this check

* Black
  • Loading branch information
lmazuel authored Jun 17, 2023
1 parent cc885e7 commit 9efde13
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 59 deletions.
4 changes: 3 additions & 1 deletion sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Release History

## 1.27.2 (Unreleased)
## 1.28.0 (Unreleased)

### Features Added

- Add header name parameter to RequestIdPolicy #30772

### Breaking Changes

### Bugs Fixed
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _build_pipeline( # pylint: disable=no-self-use

if policies is None: # [] is a valid policy list
policies = [
RequestIdPolicy(**kwargs),
config.request_id_policy or RequestIdPolicy(**kwargs),
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_pipeline_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _build_pipeline( # pylint: disable=no-self-use

if policies is None: # [] is a valid policy list
policies = [
RequestIdPolicy(**kwargs),
config.request_id_policy or RequestIdPolicy(**kwargs),
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# regenerated.
# --------------------------------------------------------------------------

VERSION = "1.27.2"
VERSION = "1.28.0"
31 changes: 20 additions & 11 deletions sdk/core/azure-core/azure/core/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import Union, Optional
from typing import Union, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from .pipeline.policies import HTTPPolicy, AsyncHTTPPolicy, SansIOHTTPPolicy

class Configuration:
AnyPolicy = Union[HTTPPolicy, AsyncHTTPPolicy, SansIOHTTPPolicy]


class Configuration: # pylint: disable=too-many-instance-attributes
"""Provides the home for all of the configurable policies in the pipeline.
A new Configuration object provides no default policies and does not specify in what
Expand All @@ -46,6 +51,7 @@ class Configuration:
User-Agent header.
:ivar authentication_policy: Provides configuration parameters for adding a bearer token Authorization
header to requests.
:ivar request_id_policy: Provides configuration parameters for adding a request id to requests.
:keyword polling_interval: Polling interval while doing LRO operations, if Retry-After is not set.
.. admonition:: Example:
Expand All @@ -59,31 +65,34 @@ class Configuration:

def __init__(self, **kwargs):
# Headers (sent with every request)
self.headers_policy = None
self.headers_policy: "Optional[AnyPolicy]" = None

# Proxy settings (Currently used to configure transport, could be pipeline policy instead)
self.proxy_policy = None
self.proxy_policy: "Optional[AnyPolicy]" = None

# Redirect configuration
self.redirect_policy = None
self.redirect_policy: "Optional[AnyPolicy]" = None

# Retry configuration
self.retry_policy = None
self.retry_policy: "Optional[AnyPolicy]" = None

# Custom hook configuration
self.custom_hook_policy = None
self.custom_hook_policy: "Optional[AnyPolicy]" = None

# Logger configuration
self.logging_policy = None
self.logging_policy: "Optional[AnyPolicy]" = None

# Http logger configuration
self.http_logging_policy = None
self.http_logging_policy: "Optional[AnyPolicy]" = None

# User Agent configuration
self.user_agent_policy = None
self.user_agent_policy: "Optional[AnyPolicy]" = None

# Authentication configuration
self.authentication_policy = None
self.authentication_policy: "Optional[AnyPolicy]" = None

# Request ID policy
self.request_id_policy: "Optional[AnyPolicy]" = None

# Polling interval if no retry-after in polling calls results
self.polling_interval = kwargs.get("polling_interval", 30)
Expand Down
78 changes: 46 additions & 32 deletions sdk/core/azure-core/azure/core/pipeline/policies/_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import types
import re
import uuid
from typing import IO, cast, Union, Optional, AnyStr, Dict, MutableMapping
from typing import IO, cast, Union, Optional, AnyStr, Dict, MutableMapping, Any, Set, Mapping
import urllib.parse
from typing_extensions import Protocol

Expand Down Expand Up @@ -97,13 +97,13 @@ class HeadersPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
"""

def __init__(
self, base_headers: Optional[Dict[str, str]] = None, **kwargs
self, base_headers: Optional[Dict[str, str]] = None, **kwargs: Any
) -> None: # pylint: disable=super-init-not-called
self._headers: Dict[str, str] = base_headers or {}
self._headers.update(kwargs.pop("headers", {}))

@property
def headers(self):
def headers(self) -> Dict[str, str]:
"""The current headers collection."""
return self._headers

Expand Down Expand Up @@ -140,6 +140,7 @@ class RequestIdPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
:keyword str request_id: The request id to be added into header.
:keyword bool auto_request_id: Auto generates a unique request ID per call if true which is by default.
:keyword str request_id_header_name: Header name to use. Default is "x-ms-client-request-id".
.. admonition:: Example:
Expand All @@ -151,9 +152,18 @@ class RequestIdPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
:caption: Configuring a request id policy.
"""

def __init__(self, **kwargs) -> None: # pylint: disable=super-init-not-called
self._request_id = kwargs.pop("request_id", _Unset)
self._auto_request_id = kwargs.pop("auto_request_id", True)
def __init__(
self, # pylint: disable=unused-argument
*,
request_id: Union[str, Any] = _Unset,
auto_request_id: bool = True,
request_id_header_name: str = "x-ms-client-request-id",
**kwargs: Any
) -> None:
super()
self._request_id = request_id
self._auto_request_id = auto_request_id
self._request_id_header_name = request_id_header_name

def set_request_id(self, value: str) -> None:
"""Add the request id to the configuration to be applied to all requests.
Expand All @@ -176,15 +186,15 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
elif self._request_id is None:
return
elif self._request_id is not _Unset:
if "x-ms-client-request-id" in request.http_request.headers:
if self._request_id_header_name in request.http_request.headers:
return
request_id = self._request_id
elif self._auto_request_id:
if "x-ms-client-request-id" in request.http_request.headers:
if self._request_id_header_name in request.http_request.headers:
return
request_id = str(uuid.uuid1())
if request_id is not unset:
header = {"x-ms-client-request-id": cast(str, request_id)}
header = {self._request_id_header_name: cast(str, request_id)}
request.http_request.headers.update(header)


Expand Down Expand Up @@ -213,12 +223,12 @@ class UserAgentPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
_ENV_ADDITIONAL_USER_AGENT = "AZURE_HTTP_USER_AGENT"

def __init__(
self, base_user_agent: Optional[str] = None, **kwargs
self, base_user_agent: Optional[str] = None, **kwargs: Any
) -> None: # pylint: disable=super-init-not-called
self.overwrite = kwargs.pop("user_agent_overwrite", False)
self.use_env = kwargs.pop("user_agent_use_env", True)
application_id = kwargs.pop("user_agent", None)
sdk_moniker = kwargs.pop("sdk_moniker", "core/{}".format(azcore_version))
self.overwrite: bool = kwargs.pop("user_agent_overwrite", False)
self.use_env: bool = kwargs.pop("user_agent_use_env", True)
application_id: Optional[str] = kwargs.pop("user_agent", None)
sdk_moniker: str = kwargs.pop("sdk_moniker", "core/{}".format(azcore_version))

if base_user_agent:
self._user_agent = base_user_agent
Expand Down Expand Up @@ -283,7 +293,7 @@ class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseTy
:caption: Configuring a network trace logging policy.
"""

def __init__(self, logging_enable=False, **kwargs): # pylint: disable=unused-argument
def __init__(self, logging_enable: bool = False, **kwargs: Any): # pylint: disable=unused-argument
self.enable_http_logger = logging_enable

def on_request(
Expand Down Expand Up @@ -382,11 +392,11 @@ class _HiddenClassProperties(type):
# https://github.com/Azure/azure-sdk-for-python/issues/26331

@property
def DEFAULT_HEADERS_WHITELIST(cls):
def DEFAULT_HEADERS_WHITELIST(cls) -> Set[str]:
return cls.DEFAULT_HEADERS_ALLOWLIST

@DEFAULT_HEADERS_WHITELIST.setter
def DEFAULT_HEADERS_WHITELIST(cls, value):
def DEFAULT_HEADERS_WHITELIST(cls, value: Set[str]):
cls.DEFAULT_HEADERS_ALLOWLIST = value


Expand All @@ -396,7 +406,7 @@ class HttpLoggingPolicy(
):
"""The Pipeline policy that handles logging of HTTP requests and responses."""

DEFAULT_HEADERS_ALLOWLIST = set(
DEFAULT_HEADERS_ALLOWLIST: Set[str] = set(
[
"x-ms-request-id",
"x-ms-client-request-id",
Expand Down Expand Up @@ -425,19 +435,19 @@ class HttpLoggingPolicy(
"WWW-Authenticate", # OAuth Challenge header.
]
)
REDACTED_PLACEHOLDER = "REDACTED"
MULTI_RECORD_LOG = "AZURE_SDK_LOGGING_MULTIRECORD"
REDACTED_PLACEHOLDER: str = "REDACTED"
MULTI_RECORD_LOG: str = "AZURE_SDK_LOGGING_MULTIRECORD"

def __init__(self, logger=None, **kwargs): # pylint: disable=unused-argument
self.logger = logger or logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
self.allowed_query_params = set()
self.allowed_header_names = set(self.__class__.DEFAULT_HEADERS_ALLOWLIST)
def __init__(self, logger: Optional[logging.Logger] = None, **kwargs: Any): # pylint: disable=unused-argument
self.logger: logging.Logger = logger or logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
self.allowed_query_params: Set[str] = set()
self.allowed_header_names: Set[str] = set(self.__class__.DEFAULT_HEADERS_ALLOWLIST)

def _redact_query_param(self, key, value):
def _redact_query_param(self, key: str, value: str) -> str:
lower_case_allowed_query_params = [param.lower() for param in self.allowed_query_params]
return value if key.lower() in lower_case_allowed_query_params else HttpLoggingPolicy.REDACTED_PLACEHOLDER

def _redact_header(self, key, value):
def _redact_header(self, key: str, value: str) -> str:
lower_case_allowed_header_names = [header.lower() for header in self.allowed_header_names]
return value if key.lower() in lower_case_allowed_header_names else HttpLoggingPolicy.REDACTED_PLACEHOLDER

Expand Down Expand Up @@ -564,7 +574,9 @@ class ContentDecodePolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
# Name used in context
CONTEXT_NAME = "deserialized_data"

def __init__(self, response_encoding: Optional[str] = None, **kwargs) -> None: # pylint: disable=unused-argument
def __init__(
self, response_encoding: Optional[str] = None, **kwargs: Any # pylint: disable=unused-argument
) -> None:
self._response_encoding = response_encoding

@classmethod
Expand All @@ -573,7 +585,7 @@ def deserialize_from_text(
data: Optional[Union[AnyStr, IO]],
mime_type: Optional[str] = None,
response: Optional[HTTPResponseType] = None,
):
) -> Any:
"""Decode response data according to content-type.
Accept a stream of data as well, but will be load at once in memory for now.
Expand All @@ -584,7 +596,7 @@ def deserialize_from_text(
:param str mime_type: The mime type. As mime type, charset is not expected.
:param response: If passed, exception will be annotated with that response
:raises ~azure.core.exceptions.DecodeError: If deserialization fails
:returns: A dict or XML tree, depending of the mime_type
:returns: A dict (JSON), XML tree or str, depending of the mime_type
"""
if not data:
return None
Expand Down Expand Up @@ -643,15 +655,15 @@ def deserialize_from_http_generics(
cls,
response: HTTPResponseType,
encoding: Optional[str] = None,
):
) -> Any:
"""Deserialize from HTTP response.
Headers will tested for "content-type"
:param response: The HTTP response
:param encoding: The encoding to use if known for this service (will disable auto-detection)
:raises ~azure.core.exceptions.DecodeError: If deserialization fails
:returns: A dict or XML tree, depending of the mime-type
:returns: A dict (JSON), XML tree or str, depending of the mime_type
"""
# Try to use content-type from headers if available
if response.content_type:
Expand Down Expand Up @@ -733,7 +745,9 @@ class ProxyPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
:caption: Configuring a proxy policy.
"""

def __init__(self, proxies=None, **kwargs): # pylint: disable=unused-argument,super-init-not-called
def __init__(
self, proxies: Optional[Mapping[str, str]] = None, **kwargs: Any
): # pylint: disable=unused-argument,super-init-not-called
self.proxies = proxies

def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.azure-sdk-build]
mypy = true
type_check_samples = true
verifytypes = true
verifytypes = false
pyright = false
# For test environments or static checks where a check should be run by default, not explicitly disabling will enable the check.
# pylint is enabled by default, so there is no reason for a pylint = true in every pyproject.toml.
Expand Down
36 changes: 25 additions & 11 deletions sdk/core/azure-core/tests/test_request_id_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,30 @@
request_id_init_values = ("foo", None, "_unset")
request_id_set_values = ("bar", None, "_unset")
request_id_req_values = ("baz", None, "_unset")
request_id_header_name_values = ("client-request-id", "_unset")
full_combination = list(
product(auto_request_id_values, request_id_init_values, request_id_set_values, request_id_req_values, HTTP_REQUESTS)
product(
auto_request_id_values,
request_id_init_values,
request_id_set_values,
request_id_req_values,
request_id_header_name_values,
HTTP_REQUESTS,
)
)


@pytest.mark.parametrize(
"auto_request_id, request_id_init, request_id_set, request_id_req, http_request", full_combination
"auto_request_id, request_id_init, request_id_set, request_id_req, request_id_header_name, http_request",
full_combination,
)
def test_request_id_policy(auto_request_id, request_id_init, request_id_set, request_id_req, http_request):
def test_request_id_policy(
auto_request_id, request_id_init, request_id_set, request_id_req, request_id_header_name, http_request
):
"""Test policy with no other policy and happy path"""
kwargs = {}
if request_id_header_name != "_unset":
kwargs["request_id_header_name"] = request_id_header_name
if auto_request_id is not None:
kwargs["auto_request_id"] = auto_request_id
if request_id_init != "_unset":
Expand All @@ -44,22 +57,23 @@ def test_request_id_policy(auto_request_id, request_id_init, request_id_set, req
request_id_policy.on_request(pipeline_request)

assert all(v is not None for v in request.headers.values())
expected_header_name = "x-ms-client-request-id" if request_id_header_name == "_unset" else "client-request-id"
if request_id_req != "_unset" and request_id_req:
assert request.headers["x-ms-client-request-id"] == request_id_req
assert request.headers[expected_header_name] == request_id_req
elif not request_id_req:
assert not "x-ms-client-request-id" in request.headers
assert not expected_header_name in request.headers
elif request_id_set != "_unset" and request_id_set:
assert request.headers["x-ms-client-request-id"] == request_id_set
assert request.headers[expected_header_name] == request_id_set
elif not request_id_set:
assert not "x-ms-client-request-id" in request.headers
assert not expected_header_name in request.headers
elif request_id_init != "_unset" and request_id_init:
assert request.headers["x-ms-client-request-id"] == request_id_init
assert request.headers[expected_header_name] == request_id_init
elif not request_id_init:
assert not "x-ms-client-request-id" in request.headers
assert not expected_header_name in request.headers
elif auto_request_id or auto_request_id is None:
assert request.headers["x-ms-client-request-id"] == "VALUE"
assert request.headers[expected_header_name] == "VALUE"
else:
assert not "x-ms-client-request-id" in request.headers
assert not expected_header_name in request.headers


@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
Expand Down

0 comments on commit 9efde13

Please sign in to comment.