Skip to content

Commit

Permalink
[EG] Archboard Feedback (#35738)
Browse files Browse the repository at this point in the history
* regen

* remove all samples/tests before fixing

* move all topic/sub to client level

* update

* updates

* update samples

* add other publisher tests

* missing

* content type

* consumerclient

* upload consumer tests

* updates

* update

* changes

* updates

* rename

* update

* patch

* test update

* update tests

* fix

* updates snippets

* update readme

* try updating api_version

* typo

Co-authored-by: swathipil <76007337+swathipil@users.noreply.github.com>

* typo2

Co-authored-by: swathipil <76007337+swathipil@users.noreply.github.com>

* renames/docs from comments

* regen

* update patch

* remove import

* caps

---------

Co-authored-by: swathipil <76007337+swathipil@users.noreply.github.com>
  • Loading branch information
l0lawrence and swathipil authored May 30, 2024
1 parent e6b05cb commit 6186af3
Show file tree
Hide file tree
Showing 82 changed files with 1,860 additions and 3,395 deletions.
16 changes: 8 additions & 8 deletions sdk/eventgrid/azure-eventgrid/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ For example, you can use `DefaultAzureCredential` to construct a client which wi

```python
from azure.identity import DefaultAzureCredential
from azure.eventgrid import EventGridClient, EventGridEvent
from azure.eventgrid import EventGridPublisherClient, EventGridEvent

default_az_credential = DefaultAzureCredential()
endpoint = os.environ["EVENTGRID_ENDPOINT"]
client = EventGridClient(endpoint, default_az_credential)
endpoint = os.environ["EVENTGRID_TOPIC_ENDPOINT"]
client = EventGridPublisherClient(endpoint, default_az_credential)
```

<!-- END SNIPPET -->
Expand All @@ -105,14 +105,14 @@ pass the key as a string into an instance of [AzureKeyCredential][azure-key-cred

```python
import os
from azure.eventgrid import EventGridClient
from azure.eventgrid import EventGridPublisherClient
from azure.core.credentials import AzureKeyCredential

key = os.environ["EVENTGRID_KEY"]
endpoint = os.environ["EVENTGRID_ENDPOINT"]
topic_key = os.environ["EVENTGRID_TOPIC_KEY"]
endpoint = os.environ["EVENTGRID_TOPIC_ENDPOINT"]

credential_key = AzureKeyCredential(key)
client = EventGridClient(endpoint, credential_key)
credential_key = AzureKeyCredential(topic_key)
client = EventGridPublisherClient(endpoint, credential_key)
```

<!-- END SNIPPET -->
Expand Down
6 changes: 4 additions & 2 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from ._patch import EventGridClient
from ._patch import EventGridPublisherClient
from ._patch import EventGridConsumerClient
from ._version import VERSION

__version__ = VERSION
Expand All @@ -19,7 +20,8 @@
from ._patch import patch_sdk as _patch_sdk

__all__ = [
"EventGridClient",
"EventGridPublisherClient",
"EventGridConsumerClient",
]
__all__.extend([p for p in _patch_all if p not in __all__])

Expand Down
99 changes: 90 additions & 9 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@
from azure.core.pipeline import policies
from azure.core.rest import HttpRequest, HttpResponse

from ._configuration import EventGridClientConfiguration
from ._operations import EventGridClientOperationsMixin
from ._configuration import EventGridConsumerClientConfiguration, EventGridPublisherClientConfiguration
from ._operations import EventGridConsumerClientOperationsMixin, EventGridPublisherClientOperationsMixin
from ._serialization import Deserializer, Serializer

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from azure.core.credentials import TokenCredential


class EventGridClient(EventGridClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword
"""Azure Messaging EventGrid Client.
class EventGridPublisherClient(
EventGridPublisherClientOperationsMixin
): # pylint: disable=client-accepts-api-version-keyword
"""EventGridPublisherClient.
:param endpoint: The host name of the namespace, e.g.
namespaceName1.westus-1.eventgrid.azure.net. Required.
Expand All @@ -33,15 +35,14 @@ class EventGridClient(EventGridClientOperationsMixin): # pylint: disable=client
AzureKeyCredential type or a TokenCredential type. Required.
:type credential: ~azure.core.credentials.AzureKeyCredential or
~azure.core.credentials.TokenCredential
:keyword api_version: The API version to use for this operation. Default value is
"2023-10-01-preview". Note that overriding this default value may result in unsupported
behavior.
:keyword api_version: The API version to use for this operation. Default value is "2024-06-01".
Note that overriding this default value may result in unsupported behavior.
:paramtype api_version: str
"""

def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None:
_endpoint = "{endpoint}"
self._config = EventGridClientConfiguration(endpoint=endpoint, credential=credential, **kwargs)
self._config = EventGridPublisherClientConfiguration(endpoint=endpoint, credential=credential, **kwargs)
_policies = kwargs.pop("policies", None)
if _policies is None:
_policies = [
Expand Down Expand Up @@ -94,7 +95,87 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "EventGridClient":
def __enter__(self) -> "EventGridPublisherClient":
self._client.__enter__()
return self

def __exit__(self, *exc_details: Any) -> None:
self._client.__exit__(*exc_details)


class EventGridConsumerClient(
EventGridConsumerClientOperationsMixin
): # pylint: disable=client-accepts-api-version-keyword
"""EventGridConsumerClient.
:param endpoint: The host name of the namespace, e.g.
namespaceName1.westus-1.eventgrid.azure.net. Required.
:type endpoint: str
:param credential: Credential used to authenticate requests to the service. Is either a
AzureKeyCredential type or a TokenCredential type. Required.
:type credential: ~azure.core.credentials.AzureKeyCredential or
~azure.core.credentials.TokenCredential
:keyword api_version: The API version to use for this operation. Default value is "2024-06-01".
Note that overriding this default value may result in unsupported behavior.
:paramtype api_version: str
"""

def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None:
_endpoint = "{endpoint}"
self._config = EventGridConsumerClientConfiguration(endpoint=endpoint, credential=credential, **kwargs)
_policies = kwargs.pop("policies", None)
if _policies is None:
_policies = [
policies.RequestIdPolicy(**kwargs),
self._config.headers_policy,
self._config.user_agent_policy,
self._config.proxy_policy,
policies.ContentDecodePolicy(**kwargs),
self._config.redirect_policy,
self._config.retry_policy,
self._config.authentication_policy,
self._config.custom_hook_policy,
self._config.logging_policy,
policies.DistributedTracingPolicy(**kwargs),
policies.SensitiveHeaderCleanupPolicy(**kwargs) if self._config.redirect_policy else None,
self._config.http_logging_policy,
]
self._client: PipelineClient = PipelineClient(base_url=_endpoint, policies=_policies, **kwargs)

self._serialize = Serializer()
self._deserialize = Deserializer()
self._serialize.client_side_validation = False

def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse:
"""Runs the network request through the client's chained policies.
>>> from azure.core.rest import HttpRequest
>>> request = HttpRequest("GET", "https://www.example.org/")
<HttpRequest [GET], url: 'https://www.example.org/'>
>>> response = client.send_request(request)
<HttpResponse: 200 OK>
For more information on this code flow, see https://aka.ms/azsdk/dpcodegen/python/send_request
:param request: The network request you want to make. Required.
:type request: ~azure.core.rest.HttpRequest
:keyword bool stream: Whether the response payload will be streamed. Defaults to False.
:return: The response of your network call. Does not do error handling on your response.
:rtype: ~azure.core.rest.HttpResponse
"""

request_copy = deepcopy(request)
path_format_arguments = {
"endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True),
}

request_copy.url = self._client.format_url(request_copy.url, **path_format_arguments)
return self._client.send_request(request_copy, stream=stream, **kwargs) # type: ignore

def close(self) -> None:
self._client.close()

def __enter__(self) -> "EventGridConsumerClient":
self._client.__enter__()
return self

Expand Down
68 changes: 62 additions & 6 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from azure.core.credentials import TokenCredential


class EventGridClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long
"""Configuration for EventGridClient.
class EventGridPublisherClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long
"""Configuration for EventGridPublisherClient.
Note that all parameters used to create this instance are saved as instance
attributes.
Expand All @@ -31,14 +31,70 @@ class EventGridClientConfiguration: # pylint: disable=too-many-instance-attribu
AzureKeyCredential type or a TokenCredential type. Required.
:type credential: ~azure.core.credentials.AzureKeyCredential or
~azure.core.credentials.TokenCredential
:keyword api_version: The API version to use for this operation. Default value is
"2023-10-01-preview". Note that overriding this default value may result in unsupported
behavior.
:keyword api_version: The API version to use for this operation. Default value is "2024-06-01".
Note that overriding this default value may result in unsupported behavior.
:paramtype api_version: str
"""

def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None:
api_version: str = kwargs.pop("api_version", "2023-10-01-preview")
api_version: str = kwargs.pop("api_version", "2024-06-01")

if endpoint is None:
raise ValueError("Parameter 'endpoint' must not be None.")
if credential is None:
raise ValueError("Parameter 'credential' must not be None.")

self.endpoint = endpoint
self.credential = credential
self.api_version = api_version
self.credential_scopes = kwargs.pop("credential_scopes", ["https://eventgrid.azure.net/.default"])
kwargs.setdefault("sdk_moniker", "eventgrid/{}".format(VERSION))
self.polling_interval = kwargs.get("polling_interval", 30)
self._configure(**kwargs)

def _infer_policy(self, **kwargs):
if isinstance(self.credential, AzureKeyCredential):
return policies.AzureKeyCredentialPolicy(
self.credential, "Authorization", prefix="SharedAccessKey", **kwargs
)
if hasattr(self.credential, "get_token"):
return policies.BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs)
raise TypeError(f"Unsupported credential: {self.credential}")

def _configure(self, **kwargs: Any) -> None:
self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs)
self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs)
self.proxy_policy = kwargs.get("proxy_policy") or policies.ProxyPolicy(**kwargs)
self.logging_policy = kwargs.get("logging_policy") or policies.NetworkTraceLoggingPolicy(**kwargs)
self.http_logging_policy = kwargs.get("http_logging_policy") or policies.HttpLoggingPolicy(**kwargs)
self.custom_hook_policy = kwargs.get("custom_hook_policy") or policies.CustomHookPolicy(**kwargs)
self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs)
self.retry_policy = kwargs.get("retry_policy") or policies.RetryPolicy(**kwargs)
self.authentication_policy = kwargs.get("authentication_policy")
if self.credential and not self.authentication_policy:
self.authentication_policy = self._infer_policy(**kwargs)


class EventGridConsumerClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long
"""Configuration for EventGridConsumerClient.
Note that all parameters used to create this instance are saved as instance
attributes.
:param endpoint: The host name of the namespace, e.g.
namespaceName1.westus-1.eventgrid.azure.net. Required.
:type endpoint: str
:param credential: Credential used to authenticate requests to the service. Is either a
AzureKeyCredential type or a TokenCredential type. Required.
:type credential: ~azure.core.credentials.AzureKeyCredential or
~azure.core.credentials.TokenCredential
:keyword api_version: The API version to use for this operation. Default value is "2024-06-01".
Note that overriding this default value may result in unsupported behavior.
:paramtype api_version: str
"""

def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any) -> None:
api_version: str = kwargs.pop("api_version", "2024-06-01")

if endpoint is None:
raise ValueError("Parameter 'endpoint' must not be None.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _is_cloud_event(event):
return False


def _is_eventgrid_event(event):
def _is_eventgrid_event_format(event):
# type: (Any) -> bool
required = ("subject", "eventType", "data", "dataVersion", "id", "eventTime")
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ._helpers import (
_get_authentication_policy,
_is_cloud_event,
_is_eventgrid_event,
_is_eventgrid_event_format,
_eventgrid_data_typecheck,
_build_request,
_cloud_event_to_generated,
Expand Down Expand Up @@ -217,7 +217,7 @@ def send(self, events: SendType, *, channel_name: Optional[str] = None, **kwargs
## this is either a dictionary or a CNCF cloud event
events = [_from_cncf_events(e) for e in events]
content_type = "application/cloudevents-batch+json; charset=utf-8"
elif isinstance(events[0], EventGridEvent) or _is_eventgrid_event(events[0]):
elif isinstance(events[0], EventGridEvent) or _is_eventgrid_event_format(events[0]):
for event in events:
_eventgrid_data_typecheck(event)
response = self._client.send_request( # pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .._models import EventGridEvent
from .._helpers import (
_is_cloud_event,
_is_eventgrid_event,
_is_eventgrid_event_format,
_eventgrid_data_typecheck,
_build_request,
_cloud_event_to_generated,
Expand Down Expand Up @@ -212,7 +212,7 @@ async def send(self, events: SendType, *, channel_name: Optional[str] = None, **
## this is either a dictionary or a CNCF cloud event
events = [_from_cncf_events(e) for e in events]
content_type = "application/cloudevents-batch+json; charset=utf-8"
elif isinstance(events[0], EventGridEvent) or _is_eventgrid_event(events[0]):
elif isinstance(events[0], EventGridEvent) or _is_eventgrid_event_format(events[0]):
for event in events:
_eventgrid_data_typecheck(event)
response = await self._client.send_request( # pylint: disable=protected-access
Expand Down
28 changes: 20 additions & 8 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# --------------------------------------------------------------------------
# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except

import copy
import calendar
import decimal
import functools
Expand Down Expand Up @@ -639,6 +640,13 @@ def _deserialize_sequence(
return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)


def _sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.Any]:
return sorted(
types,
key=lambda x: hasattr(x, "__name__") and x.__name__.lower() in ("str", "float", "int", "bool"),
)


def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
annotation: typing.Any,
module: typing.Optional[str],
Expand Down Expand Up @@ -680,21 +688,25 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
# is it optional?
try:
if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
if_obj_deserializer = _get_deserialize_callable_from_annotation(
next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
)

return functools.partial(_deserialize_with_optional, if_obj_deserializer)
if len(annotation.__args__) <= 2: # pyright: ignore
if_obj_deserializer = _get_deserialize_callable_from_annotation(
next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
)

return functools.partial(_deserialize_with_optional, if_obj_deserializer)
# the type is Optional[Union[...]], we need to remove the None type from the Union
annotation_copy = copy.copy(annotation)
annotation_copy.__args__ = [a for a in annotation_copy.__args__ if a != type(None)] # pyright: ignore
return _get_deserialize_callable_from_annotation(annotation_copy, module, rf)
except AttributeError:
pass

# is it union?
if getattr(annotation, "__origin__", None) is typing.Union:
# initial ordering is we make `string` the last deserialization option, because it is often them most generic
deserializers = [
_get_deserialize_callable_from_annotation(arg, module, rf)
for arg in sorted(
annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
)
for arg in _sorted_annotations(annotation.__args__) # pyright: ignore
]

return functools.partial(_deserialize_with_union, deserializers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from ._patch import EventGridClientOperationsMixin
from ._patch import EventGridPublisherClientOperationsMixin
from ._patch import EventGridConsumerClientOperationsMixin

from ._patch import __all__ as _patch_all
from ._patch import * # pylint: disable=unused-wildcard-import
from ._patch import patch_sdk as _patch_sdk

__all__ = [
"EventGridClientOperationsMixin",
"EventGridPublisherClientOperationsMixin",
"EventGridConsumerClientOperationsMixin",
]
__all__.extend([p for p in _patch_all if p not in __all__])
_patch_sdk()
Loading

0 comments on commit 6186af3

Please sign in to comment.