Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy Compatibilty for EventGrid #14344

Merged
merged 3 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion eng/tox/mypy_hard_failure_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
"azure-servicebus",
"azure-ai-textanalytics",
"azure-ai-formrecognizer",
"azure-ai-metricsadvisor"
"azure-ai-metricsadvisor",
"azure-eventgrid",
]
2 changes: 1 addition & 1 deletion sdk/eventgrid/azure-eventgrid/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
6 changes: 3 additions & 3 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING
from typing import cast, TYPE_CHECKING
import logging
from ._models import CloudEvent, EventGridEvent

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Union

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,7 +58,7 @@ def decode_eventgrid_event(self, eventgrid_event, **kwargs): # pylint: disable=n
eventgrid_event = EventGridEvent._from_json(eventgrid_event, encode) # pylint: disable=protected-access
deserialized_event = EventGridEvent.deserialize(eventgrid_event)
EventGridEvent._deserialize_data(deserialized_event, deserialized_event.event_type) # pylint: disable=protected-access
return deserialized_event
return cast(EventGridEvent, deserialized_event)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is necessary unlike above - we get

Returning Any from function declared to return "EventGridEvent"

because the msrest's deserialize method on line 59 isn't typed

Copy link
Member

@lmazuel lmazuel Oct 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then let's fix msrest too :)
Azure/msrest-for-python#226

But ok here, since we should not wait for msrest fix

except Exception as err:
_LOGGER.error('Error: cannot deserialize event. Event does not have a valid format. \
Event must be a string, dict, or bytes following the CloudEvent schema.')
Expand Down
8 changes: 6 additions & 2 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hashlib
import hmac
import base64
from typing import TYPE_CHECKING, Any
try:
from urllib.parse import quote
except ImportError:
Expand All @@ -16,8 +17,11 @@
from ._signature_credential_policy import EventGridSharedAccessSignatureCredentialPolicy
from . import _constants as constants

if TYPE_CHECKING:
from datetime import datetime

def generate_shared_access_signature(topic_hostname, shared_access_key, expiration_date_utc, **kwargs):
# type: (str, str, datetime.Datetime, Any) -> str
# type: (str, str, datetime, Any) -> str
""" Helper method to generate shared access signature given hostname, key, and expiration date.
:param str topic_hostname: The topic endpoint to send the events to.
Similar to <YOUR-TOPIC-NAME>.<YOUR-REGION-NAME>-1.eventgrid.azure.net
Expand Down Expand Up @@ -82,7 +86,7 @@ def _get_authentication_policy(credential):
return authentication_policy

def _is_cloud_event(event):
# type: dict -> bool
# type: (Any) -> bool
required = ('id', 'source', 'specversion', 'type')
try:
return all([_ in event for _ in required]) and event['specversion'] == "1.0"
Expand Down
2 changes: 2 additions & 0 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
# pylint:disable=protected-access
from typing import Union, Any, Dict
import datetime as dt
import uuid
import json
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, source, type, **kwargs): # pylint: disable=redefined-builtin

@classmethod
def _from_generated(cls, cloud_event, **kwargs):
# type: (Union[str, Dict, bytes], Any) -> CloudEvent
generated = InternalCloudEvent.deserialize(cloud_event)
if generated.additional_properties:
extensions = dict(generated.additional_properties)
Expand Down
3 changes: 3 additions & 0 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
# license information.
# --------------------------------------------------------------------------
import json
from typing import TYPE_CHECKING
import logging
from azure.core.pipeline.policies import SansIOHTTPPolicy

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from azure.core.pipeline import PipelineRequest

class CloudEventDistributedTracingPolicy(SansIOHTTPPolicy):
"""CloudEventDistributedTracingPolicy is a policy which adds distributed tracing informatiom
Expand Down
31 changes: 22 additions & 9 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license information.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast, Dict, List, Any, Union

from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline.policies import (
Expand All @@ -27,10 +27,12 @@
from ._generated._event_grid_publisher_client import EventGridPublisherClient as EventGridPublisherClientImpl
from ._policies import CloudEventDistributedTracingPolicy
from ._version import VERSION
from ._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Union, Dict, List
from azure.core.credentials import AzureKeyCredential
from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
SendType = Union[
CloudEvent,
EventGridEvent,
Expand All @@ -42,6 +44,13 @@
List[Dict]
]

ListEventType = Union[
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]


class EventGridPublisherClient(object):
"""EventGrid Python Publisher Client.
Expand Down Expand Up @@ -79,7 +88,7 @@ def _policies(credential, **kwargs):
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
CloudEventDistributedTracingPolicy(**kwargs),
CloudEventDistributedTracingPolicy(),
HttpLoggingPolicy(**kwargs)
]
return policies
Expand All @@ -98,20 +107,24 @@ def send(self, events, **kwargs):
:raises: :class:`ValueError`, when events do not follow specified SendType.
"""
if not isinstance(events, list):
events = [events]
events = cast(ListEventType, [events])

if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events):
try:
events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
events = [cast(CloudEvent, e)._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
except AttributeError:
pass # means it's a dictionary
kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8")
self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs)
self._client.publish_cloud_event_events(
self._topic_hostname,
cast(List[InternalCloudEvent], events),
**kwargs
)
elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events):
kwargs.setdefault("content_type", "application/json; charset=utf-8")
self._client.publish_events(self._topic_hostname, events, **kwargs)
self._client.publish_events(self._topic_hostname, cast(List[InternalEventGridEvent], events), **kwargs)
elif all(isinstance(e, CustomEvent) for e in events):
serialized_events = [dict(e) for e in events]
self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs)
serialized_events = [dict(e) for e in events] # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? What's the error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argument 1 to "dict" has incompatible type "Union[CloudEvent, EventGridEvent, CustomEvent, Dict[Any, Any]]"; expected "Mapping[Any, Any]"
I can use a cast, but it won't entirely be true - ideally we validate that it's not a cloudevent, eventgrid event by the time we hit this line and they should not be included in the union - afaik, it's a problem with mypy.

self._client.publish_custom_event_events(self._topic_hostname, cast(List, serialized_events), **kwargs)
else:
raise ValueError("Event schema is not correct.")
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
# license information.
# -------------------------------------------------------------------------

from typing import Any, TYPE_CHECKING
KieranBrantnerMagee marked this conversation as resolved.
Show resolved Hide resolved
import six

from azure.core.pipeline.policies import SansIOHTTPPolicy

if TYPE_CHECKING:
from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential


class EventGridSharedAccessSignatureCredentialPolicy(SansIOHTTPPolicy):
"""Adds a token header for the provided credential.
:param credential: The credential used to authenticate requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import Any, Union, List, Dict
from typing import Any, Union, List, Dict, cast
from azure.core.credentials import AzureKeyCredential
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import (
Expand All @@ -26,19 +26,27 @@
from .._models import CloudEvent, EventGridEvent, CustomEvent
from .._helpers import _get_topic_hostname_only_fqdn, _get_authentication_policy, _is_cloud_event
from .._generated.aio import EventGridPublisherClient as EventGridPublisherClientAsync
from .._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent
from .._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
from .._version import VERSION

SendType = Union[
CloudEvent,
EventGridEvent,
CustomEvent,
Dict,
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]
CloudEvent,
EventGridEvent,
CustomEvent,
Dict,
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]

ListEventType = Union[
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]

class EventGridPublisherClient():
"""Asynchronous EventGrid Python Publisher Client.
Expand Down Expand Up @@ -101,20 +109,34 @@ async def send(
:raises: :class:`ValueError`, when events do not follow specified SendType.
"""
if not isinstance(events, list):
events = [events]
events = cast(ListEventType, [events])

if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events):
try:
events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
events = [
cast(CloudEvent, e)._to_generated(**kwargs) for e in events # pylint: disable=protected-access
]
except AttributeError:
pass # means it's a dictionary
kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8")
await self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs)
await self._client.publish_cloud_event_events(
self._topic_hostname,
cast(List[InternalCloudEvent], events),
**kwargs
)
elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events):
kwargs.setdefault("content_type", "application/json; charset=utf-8")
await self._client.publish_events(self._topic_hostname, events, **kwargs)
await self._client.publish_events(
self._topic_hostname,
cast(List[InternalEventGridEvent], events),
**kwargs
)
elif all(isinstance(e, CustomEvent) for e in events):
serialized_events = [dict(e) for e in events]
await self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs)
serialized_events = [dict(e) for e in events] # type: ignore
await self._client.publish_custom_event_events(
self._topic_hostname,
cast(List, serialized_events),
**kwargs
)
else:
raise ValueError("Event schema is not correct.")
13 changes: 13 additions & 0 deletions sdk/eventgrid/azure-eventgrid/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[mypy]
python_version = 3.7
warn_return_any = True
warn_unused_configs = True
ignore_missing_imports = True

# Per-module options:

[mypy-azure.eventgrid._generated.*]
ignore_errors = True

[mypy-azure.core.*]
ignore_errors = True