Skip to content

Commit

Permalink
Mypy Compatibilty for EventGrid (#14344)
Browse files Browse the repository at this point in the history
* Mypy Compatibilyt for EventGrid

* Update sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py

* comments
  • Loading branch information
Rakshith Bhyravabhotla authored Oct 16, 2020
1 parent 3b315d2 commit d2441fc
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 32 deletions.
3 changes: 2 additions & 1 deletion eng/tox/mypy_hard_failure_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
"azure-servicebus",
"azure-ai-textanalytics",
"azure-ai-formrecognizer",
"azure-ai-metricsadvisor"
"azure-ai-metricsadvisor",
"azure-eventgrid",
]
2 changes: 1 addition & 1 deletion sdk/eventgrid/azure-eventgrid/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
6 changes: 3 additions & 3 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING
from typing import cast, TYPE_CHECKING
import logging
from ._models import CloudEvent, EventGridEvent

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Union

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,7 +58,7 @@ def decode_eventgrid_event(self, eventgrid_event, **kwargs): # pylint: disable=n
eventgrid_event = EventGridEvent._from_json(eventgrid_event, encode) # pylint: disable=protected-access
deserialized_event = EventGridEvent.deserialize(eventgrid_event)
EventGridEvent._deserialize_data(deserialized_event, deserialized_event.event_type) # pylint: disable=protected-access
return deserialized_event
return cast(EventGridEvent, deserialized_event)
except Exception as err:
_LOGGER.error('Error: cannot deserialize event. Event does not have a valid format. \
Event must be a string, dict, or bytes following the CloudEvent schema.')
Expand Down
8 changes: 6 additions & 2 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hashlib
import hmac
import base64
from typing import TYPE_CHECKING, Any
try:
from urllib.parse import quote
except ImportError:
Expand All @@ -16,8 +17,11 @@
from ._signature_credential_policy import EventGridSharedAccessSignatureCredentialPolicy
from . import _constants as constants

if TYPE_CHECKING:
from datetime import datetime

def generate_shared_access_signature(topic_hostname, shared_access_key, expiration_date_utc, **kwargs):
# type: (str, str, datetime.Datetime, Any) -> str
# type: (str, str, datetime, Any) -> str
""" Helper method to generate shared access signature given hostname, key, and expiration date.
:param str topic_hostname: The topic endpoint to send the events to.
Similar to <YOUR-TOPIC-NAME>.<YOUR-REGION-NAME>-1.eventgrid.azure.net
Expand Down Expand Up @@ -82,7 +86,7 @@ def _get_authentication_policy(credential):
return authentication_policy

def _is_cloud_event(event):
# type: dict -> bool
# type: (Any) -> bool
required = ('id', 'source', 'specversion', 'type')
try:
return all([_ in event for _ in required]) and event['specversion'] == "1.0"
Expand Down
2 changes: 2 additions & 0 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
# pylint:disable=protected-access
from typing import Union, Any, Dict
import datetime as dt
import uuid
import json
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, source, type, **kwargs): # pylint: disable=redefined-builtin

@classmethod
def _from_generated(cls, cloud_event, **kwargs):
# type: (Union[str, Dict, bytes], Any) -> CloudEvent
generated = InternalCloudEvent.deserialize(cloud_event)
if generated.additional_properties:
extensions = dict(generated.additional_properties)
Expand Down
3 changes: 3 additions & 0 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
# license information.
# --------------------------------------------------------------------------
import json
from typing import TYPE_CHECKING
import logging
from azure.core.pipeline.policies import SansIOHTTPPolicy

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from azure.core.pipeline import PipelineRequest

class CloudEventDistributedTracingPolicy(SansIOHTTPPolicy):
"""CloudEventDistributedTracingPolicy is a policy which adds distributed tracing informatiom
Expand Down
31 changes: 22 additions & 9 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license information.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast, Dict, List, Any, Union

from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline.policies import (
Expand All @@ -27,10 +27,12 @@
from ._generated._event_grid_publisher_client import EventGridPublisherClient as EventGridPublisherClientImpl
from ._policies import CloudEventDistributedTracingPolicy
from ._version import VERSION
from ._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Union, Dict, List
from azure.core.credentials import AzureKeyCredential
from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
SendType = Union[
CloudEvent,
EventGridEvent,
Expand All @@ -42,6 +44,13 @@
List[Dict]
]

ListEventType = Union[
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]


class EventGridPublisherClient(object):
"""EventGrid Python Publisher Client.
Expand Down Expand Up @@ -79,7 +88,7 @@ def _policies(credential, **kwargs):
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
CloudEventDistributedTracingPolicy(**kwargs),
CloudEventDistributedTracingPolicy(),
HttpLoggingPolicy(**kwargs)
]
return policies
Expand All @@ -98,20 +107,24 @@ def send(self, events, **kwargs):
:raises: :class:`ValueError`, when events do not follow specified SendType.
"""
if not isinstance(events, list):
events = [events]
events = cast(ListEventType, [events])

if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events):
try:
events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
events = [cast(CloudEvent, e)._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
except AttributeError:
pass # means it's a dictionary
kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8")
self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs)
self._client.publish_cloud_event_events(
self._topic_hostname,
cast(List[InternalCloudEvent], events),
**kwargs
)
elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events):
kwargs.setdefault("content_type", "application/json; charset=utf-8")
self._client.publish_events(self._topic_hostname, events, **kwargs)
self._client.publish_events(self._topic_hostname, cast(List[InternalEventGridEvent], events), **kwargs)
elif all(isinstance(e, CustomEvent) for e in events):
serialized_events = [dict(e) for e in events]
self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs)
serialized_events = [dict(e) for e in events] # type: ignore
self._client.publish_custom_event_events(self._topic_hostname, cast(List, serialized_events), **kwargs)
else:
raise ValueError("Event schema is not correct.")
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
# license information.
# -------------------------------------------------------------------------

from typing import Any, TYPE_CHECKING
import six

from azure.core.pipeline.policies import SansIOHTTPPolicy

if TYPE_CHECKING:
from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential


class EventGridSharedAccessSignatureCredentialPolicy(SansIOHTTPPolicy):
"""Adds a token header for the provided credential.
:param credential: The credential used to authenticate requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import Any, Union, List, Dict
from typing import Any, Union, List, Dict, cast
from azure.core.credentials import AzureKeyCredential
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import (
Expand All @@ -26,19 +26,27 @@
from .._models import CloudEvent, EventGridEvent, CustomEvent
from .._helpers import _get_topic_hostname_only_fqdn, _get_authentication_policy, _is_cloud_event
from .._generated.aio import EventGridPublisherClient as EventGridPublisherClientAsync
from .._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent
from .._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
from .._version import VERSION

SendType = Union[
CloudEvent,
EventGridEvent,
CustomEvent,
Dict,
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]
CloudEvent,
EventGridEvent,
CustomEvent,
Dict,
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]

ListEventType = Union[
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]

class EventGridPublisherClient():
"""Asynchronous EventGrid Python Publisher Client.
Expand Down Expand Up @@ -101,20 +109,34 @@ async def send(
:raises: :class:`ValueError`, when events do not follow specified SendType.
"""
if not isinstance(events, list):
events = [events]
events = cast(ListEventType, [events])

if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events):
try:
events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
events = [
cast(CloudEvent, e)._to_generated(**kwargs) for e in events # pylint: disable=protected-access
]
except AttributeError:
pass # means it's a dictionary
kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8")
await self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs)
await self._client.publish_cloud_event_events(
self._topic_hostname,
cast(List[InternalCloudEvent], events),
**kwargs
)
elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events):
kwargs.setdefault("content_type", "application/json; charset=utf-8")
await self._client.publish_events(self._topic_hostname, events, **kwargs)
await self._client.publish_events(
self._topic_hostname,
cast(List[InternalEventGridEvent], events),
**kwargs
)
elif all(isinstance(e, CustomEvent) for e in events):
serialized_events = [dict(e) for e in events]
await self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs)
serialized_events = [dict(e) for e in events] # type: ignore
await self._client.publish_custom_event_events(
self._topic_hostname,
cast(List, serialized_events),
**kwargs
)
else:
raise ValueError("Event schema is not correct.")
13 changes: 13 additions & 0 deletions sdk/eventgrid/azure-eventgrid/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[mypy]
python_version = 3.7
warn_return_any = True
warn_unused_configs = True
ignore_missing_imports = True

# Per-module options:

[mypy-azure.eventgrid._generated.*]
ignore_errors = True

[mypy-azure.core.*]
ignore_errors = True

0 comments on commit d2441fc

Please sign in to comment.