diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 6a53d652f..c1e60ea2b 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -440,7 +440,7 @@ class ProtoClassMetadata: def __init__(self, cls: Type["Message"]): by_field = {} - by_group = {} + by_group: Dict[str, Set] = {} by_field_name = {} by_field_number = {} @@ -780,7 +780,7 @@ def FromString(cls: Type[T], data: bytes) -> T: def to_dict( self, casing: Casing = Casing.CAMEL, include_default_values: bool = False - ) -> dict: + ) -> Dict[str, Any]: """ Returns a dict representation of this message instance which can be used to serialize to e.g. JSON. Defaults to camel casing for diff --git a/betterproto/_types.py b/betterproto/_types.py index 0ff23e45e..d03432cd0 100644 --- a/betterproto/_types.py +++ b/betterproto/_types.py @@ -1,4 +1,8 @@ -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from . import Message + from grpclib._protocols import IProtoMessage # Bound type variable to allow methods to return `self` of subclasses T = TypeVar("T", bound="Message") diff --git a/betterproto/grpc/grpclib_client.py b/betterproto/grpc/grpclib_client.py index 7218574b8..7f48fb995 100644 --- a/betterproto/grpc/grpclib_client.py +++ b/betterproto/grpc/grpclib_client.py @@ -3,9 +3,10 @@ import grpclib.const from typing import ( Any, + AsyncIterable, AsyncIterator, Collection, - Iterator, + Iterable, Mapping, Optional, Tuple, @@ -23,7 +24,7 @@ _Value = Union[str, bytes] _MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] -_MessageSource = Union[Iterator["IProtoMessage"], AsyncIterator["IProtoMessage"]] +_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] class ServiceStub(ABC): @@ -160,7 +161,7 @@ async def _stream_stream( @staticmethod async def _send_messages(stream, messages: _MessageSource): - if hasattr(messages, "__aiter__"): + if isinstance(messages, AsyncIterable): async for message in messages: await stream.send_message(message) else: diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 85fd9057c..ed14e000d 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -6,10 +6,10 @@ import stringcase import sys import textwrap -from typing import List +from typing import List, Union +import betterproto from betterproto.casing import safe_snake_case from betterproto.compile.importing import get_ref_type -import betterproto try: # betterproto[compiler] specific dependencies @@ -58,8 +58,8 @@ def py_type( raise NotImplementedError(f"Unknown type {descriptor.type}") -def get_py_zero(type_num: int) -> str: - zero = 0 +def get_py_zero(type_num: int) -> Union[str, float]: + zero: Union[str, float] = 0 if type_num in []: zero = 0.0 elif type_num == 8: