Skip to content

Fix parameters missing from services #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,15 +379,10 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
elif proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
# Convert the `datetime` to a timestamp message.
seconds = int(value.timestamp())
nanos = int(value.microsecond * 1e3)
value = _Timestamp(seconds=seconds, nanos=nanos)
value = _Timestamp.from_datetime(value)
elif isinstance(value, timedelta):
# Convert the `timedelta` to a duration message.
total_ms = value // timedelta(microseconds=1)
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
value = _Duration.from_timedelta(value)
elif wraps:
if value is None:
return b""
Expand Down Expand Up @@ -1505,6 +1500,15 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]


class _Duration(Duration):
@classmethod
def from_timedelta(
cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1)
) -> "_Duration":
total_ms = delta // _1_microsecond
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
return cls(seconds, nanos)

def to_timedelta(self) -> timedelta:
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)

Expand All @@ -1518,6 +1522,12 @@ def delta_to_json(delta: timedelta) -> str:


class _Timestamp(Timestamp):
@classmethod
def from_datetime(cls, dt: datetime) -> "_Timestamp":
seconds = int(dt.timestamp())
nanos = int(dt.microsecond * 1e3)
return cls(seconds, nanos)

def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
return datetime.fromtimestamp(ts, tz=timezone.utc)
Expand Down
2 changes: 1 addition & 1 deletion src/betterproto/compile/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:


def get_type_reference(
package: str, imports: set, source_type: str, unwrap: bool = True
*, package: str, imports: set, source_type: str, unwrap: bool = True
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
Expand Down
39 changes: 20 additions & 19 deletions src/betterproto/grpc/grpclib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@

import grpclib.const

from .._types import (
ST,
T,
)


if TYPE_CHECKING:
from grpclib.client import Channel
from grpclib.metadata import Deadline

from .._types import (
ST,
IProtoMessage,
Message,
T,
)


Value = Union[str, bytes]
MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
MessageLike = Union[T, ST]
MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]


class ServiceStub(ABC):
Expand Down Expand Up @@ -65,13 +66,13 @@ def __resolve_request_kwargs(
async def _unary_unary(
self,
route: str,
request: MessageLike,
response_type: Type[T],
request: "IProtoMessage",
response_type: Type["T"],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[MetadataLike] = None,
) -> T:
) -> "T":
"""Make a unary request and return the response."""
async with self.channel.request(
route,
Expand All @@ -88,13 +89,13 @@ async def _unary_unary(
async def _unary_stream(
self,
route: str,
request: MessageLike,
response_type: Type[T],
request: "IProtoMessage",
response_type: Type["T"],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[MetadataLike] = None,
) -> AsyncIterator[T]:
) -> AsyncIterator["T"]:
"""Make a unary request and return the stream response iterator."""
async with self.channel.request(
route,
Expand All @@ -111,13 +112,13 @@ async def _stream_unary(
self,
route: str,
request_iterator: MessageSource,
request_type: Type[ST],
response_type: Type[T],
request_type: Type["IProtoMessage"],
response_type: Type["T"],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[MetadataLike] = None,
) -> T:
) -> "T":
"""Make a stream request and return the response."""
async with self.channel.request(
route,
Expand All @@ -135,13 +136,13 @@ async def _stream_stream(
self,
route: str,
request_iterator: MessageSource,
request_type: Type[ST],
response_type: Type[T],
request_type: Type["IProtoMessage"],
response_type: Type["T"],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[MetadataLike] = None,
) -> AsyncIterator[T]:
) -> AsyncIterator["T"]:
"""
Make a stream request and return an AsyncIterator to iterate over response
messages.
Expand Down
3 changes: 3 additions & 0 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class OutputTemplate:
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
output: bool = True

@property
def package(self) -> str:
Expand Down Expand Up @@ -704,6 +705,7 @@ def __post_init__(self) -> None:

# add imports required for request arguments timeout, deadline and metadata
self.output_file.typing_imports.add("Optional")
self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
)
Expand Down Expand Up @@ -768,6 +770,7 @@ def py_input_message_type(self) -> str:
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.input_type,
unwrap=False,
).strip('"')

@property
Expand Down
18 changes: 10 additions & 8 deletions src/betterproto/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
request_data = PluginRequestCompiler(plugin_request_obj=request)
# Gather output packages
for proto_file in request.proto_file:
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
# If not INCLUDE_GOOGLE,
# skip re-compiling Google's well-known types
continue

output_package_name = proto_file.package
if output_package_name not in request_data.output_packages:
# Create a new output if there is no output for this package
Expand All @@ -91,6 +83,14 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# Add this input file to the output corresponding to this package
request_data.output_packages[output_package_name].input_files.append(proto_file)

if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
# If not INCLUDE_GOOGLE,
# skip outputting Google's well-known types
request_data.output_packages[output_package_name].output = False

# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
Expand All @@ -113,6 +113,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# Generate output files
output_paths: Set[pathlib.Path] = set()
for output_package_name, output_package in request_data.output_packages.items():
if not output_package.output:
continue

# Add files to the response object
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
Expand Down
13 changes: 8 additions & 5 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %}

import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}

{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
{% if output_file.services %}
import grpclib
{% endif %}

{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -96,9 +97,11 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%}
,
*
, timeout: Optional[float] = None
, deadline: Optional["Deadline"] = None
, metadata: Optional["_MetadataLike"] = None
, metadata: Optional["MetadataLike"] = None
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
Expand Down Expand Up @@ -179,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase):
{% endfor %}

{% for method in service.methods %}
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
async def __rpc_{{ method.py_name }}(self, stream: "grpclib.server.Stream[{{ method.py_input_message_type }}, {{ method.py_output_message_type }}]") -> None:
{% if not method.client_streaming %}
request = await stream.recv_message()
{% else %}
Expand Down
1 change: 1 addition & 0 deletions tests/inputs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
}

services = {
"googletypes_request",
"googletypes_response",
"googletypes_response_embedded",
"service",
Expand Down
29 changes: 29 additions & 0 deletions tests/inputs/googletypes_request/googletypes_request.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
syntax = "proto3";

package googletypes_request;

import "google/protobuf/duration.proto";
import "google/protobuf/empty.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";

// Tests that google types can be used as params

service Test {
rpc SendDouble (google.protobuf.DoubleValue) returns (Input);
rpc SendFloat (google.protobuf.FloatValue) returns (Input);
rpc SendInt64 (google.protobuf.Int64Value) returns (Input);
rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input);
rpc SendInt32 (google.protobuf.Int32Value) returns (Input);
rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input);
rpc SendBool (google.protobuf.BoolValue) returns (Input);
rpc SendString (google.protobuf.StringValue) returns (Input);
rpc SendBytes (google.protobuf.BytesValue) returns (Input);
rpc SendDatetime (google.protobuf.Timestamp) returns (Input);
rpc SendTimedelta (google.protobuf.Duration) returns (Input);
rpc SendEmpty (google.protobuf.Empty) returns (Input);
}

message Input {

}
47 changes: 47 additions & 0 deletions tests/inputs/googletypes_request/test_googletypes_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from datetime import (
datetime,
timedelta,
)
from typing import (
Any,
Callable,
)

import pytest

import betterproto.lib.google.protobuf as protobuf
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_request import (
Input,
TestStub,
)


test_cases = [
(TestStub.send_double, protobuf.DoubleValue, 2.5),
(TestStub.send_float, protobuf.FloatValue, 2.5),
(TestStub.send_int64, protobuf.Int64Value, -64),
(TestStub.send_u_int64, protobuf.UInt64Value, 64),
(TestStub.send_int32, protobuf.Int32Value, -32),
(TestStub.send_u_int32, protobuf.UInt32Value, 32),
(TestStub.send_bool, protobuf.BoolValue, True),
(TestStub.send_string, protobuf.StringValue, "string"),
(TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
(TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)),
(TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)),
]


@pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[Input()])
service = TestStub(channel)

await service_method(service, wrapped_value)

assert channel.requests[0]["request"] == type(wrapped_value)