Skip to content

Commit

Permalink
feat: implement grpc transcode for rest transport and complete genera…
Browse files Browse the repository at this point in the history
…ted tests (#999)

feat: implement grpc transcode for rest transport and complete generated tests.
  • Loading branch information
kbandes authored Sep 30, 2021
1 parent 5f87973 commit ccdd17d
Show file tree
Hide file tree
Showing 12 changed files with 415 additions and 220 deletions.
14 changes: 14 additions & 0 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import collections
import dataclasses
import json
import re
from itertools import chain
from typing import (Any, cast, Dict, FrozenSet, Iterable, List, Mapping,
Expand All @@ -39,6 +40,7 @@
from google.api import http_pb2
from google.api import resource_pb2
from google.api_core import exceptions # type: ignore
from google.api_core import path_template # type: ignore
from google.protobuf import descriptor_pb2 # type: ignore
from google.protobuf.json_format import MessageToDict # type: ignore

Expand Down Expand Up @@ -714,6 +716,18 @@ class HttpRule:
uri: str
body: Optional[str]

@property
def path_fields(self) -> List[Tuple[str, str]]:
"""return list of (name, template) tuples extracted from uri."""
return [(match.group("name"), match.group("template"))
for match in path_template._VARIABLE_RE.finditer(self.uri)]

@property
def sample_request(self) -> str:
"""return json dict for sample request matching the uri template."""
sample = utils.sample_from_path_fields(self.path_fields)
return json.dumps(sample)

@classmethod
def try_parse_http_rule(cls, http_rule) -> Optional['HttpRule']:
method = http_rule.WhichOneof("pattern")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,39 @@
from google.auth.transport.requests import AuthorizedSession
import json # type: ignore
import grpc # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.api_core import exceptions as core_exceptions # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import rest_helpers # type: ignore
from google.api_core import path_template # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import operations_v1
from requests import __version__ as requests_version
from typing import Callable, Dict, Optional, Sequence, Tuple
import warnings
{% extends '_base.py.j2' %}

{% block content %}

import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple
from requests import __version__ as requests_version

{% if service.has_lro %}
from google.api_core import operations_v1
{% endif %}
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.api_core import exceptions as core_exceptions # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore

import grpc # type: ignore

from google.auth.transport.requests import AuthorizedSession

{# TODO(yon-mg): re-add python_import/ python_modules from removed diff/current grpc template code #}
{% filter sort_lines %}
{% for method in service.methods.values() %}
{{ method.input.ident.python_import }}
{{ method.output.ident.python_import }}
{{method.input.ident.python_import}}
{{method.output.ident.python_import}}
{% endfor %}
{% if opts.add_iam_methods %}
from google.iam.v1 import iam_policy_pb2 # type: ignore
from google.iam.v1 import policy_pb2 # type: ignore
{% endif %}
{% endfilter %}

from .base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO
from .base import {{service.name}}Transport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO


DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
Expand All @@ -40,7 +42,7 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=requests_version,
)

class {{ service.name }}RestTransport({{ service.name }}Transport):
class {{service.name}}RestTransport({{service.name}}Transport):
"""REST backend transport for {{ service.name }}.

{{ service.meta.doc|rst(width=72, indent=4) }}
Expand All @@ -54,13 +56,15 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{# TODO(yon-mg): handle mtls stuff if that's relevant for rest transport #}
def __init__(self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
credentials: ga_credentials.Credentials = None,
credentials_file: str = None,
scopes: Sequence[str] = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
credentials: ga_credentials.Credentials=None,
credentials_file: str=None,
scopes: Sequence[str]=None,
client_cert_source_for_mtls: Callable[[
], Tuple[bytes, bytes]]=None,
quota_project_id: Optional[str]=None,
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -88,6 +92,11 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
url_scheme: the protocol scheme for the API endpoint. Normally
"https", but for testing or local servers,
"http" can be specified.
"""
# Run the base constructor
# TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
Expand All @@ -99,7 +108,8 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
client_info=client_info,
always_use_jwt_access=always_use_jwt_access,
)
self._session = AuthorizedSession(self._credentials, default_host=self.DEFAULT_HOST)
self._session = AuthorizedSession(
self._credentials, default_host=self.DEFAULT_HOST)
{% if service.has_lro %}
self._operations_client = None
{% endif %}
Expand Down Expand Up @@ -136,16 +146,17 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):

# Return the client from cache.
return self._operations_client


{% endif %}
{% for method in service.methods.values() %}
{% if method.http_opt %}

def {{ method.name|snake_case }}(self,
request: {{ method.input.ident }}, *,
retry: retries.Retry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> {{ method.output.ident }}:
{%- if method.http_options and not method.lro and not (method.server_streaming or method.client_streaming) %}
def _{{method.name | snake_case}}(self,
request: {{method.input.ident}}, *,
retry: retries.Retry=gapic_v1.method.DEFAULT,
timeout: float=None,
metadata: Sequence[Tuple[str, str]]=(),
) -> {{method.output.ident}}:
r"""Call the {{- ' ' -}}
{{ (method.name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
Expand All @@ -168,62 +179,57 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{% endif %}
"""

{# TODO(yon-mg): refactor when implementing grpc transcoding
- parse request pb & assign body, path params
- shove leftovers into query params
- make sure dotted nested fields preserved
- format url and send the request
#}
{% if 'body' in method.http_opt %}
http_options = [
{%- for rule in method.http_options %}{
'method': '{{ rule.method }}',
'uri': '{{ rule.uri }}',
{%- if rule.body %}
'body': '{{ rule.body }}',
{%- endif %}
},
{%- endfor %}]

request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)

{% set body_spec = method.http_options[0].body %}
{%- if body_spec %}

# Jsonify the request body
{% if method.http_opt['body'] != '*' %}
body = {{ method.input.fields[method.http_opt['body']].type.ident }}.to_json(
request.{{ method.http_opt['body'] }},
body = {% if body_spec == '*' -%}
{{method.input.ident}}.to_json(
{{method.input.ident}}(transcoded_request['body']),
{%- else -%}
{{method.input.fields[body_spec].type.ident}}.to_json(
{{method.input.fields[body_spec].type.ident}}(
transcoded_request['body']),
{%- endif %}

including_default_value_fields=False,
use_integers_for_enums=False
)
{% else %}
body = {{ method.input.ident }}.to_json(
request,
use_integers_for_enums=False
)
{% endif %}
{% endif %}
{%- endif %}{# body_spec #}

{# TODO(yon-mg): Write helper method for handling grpc transcoding url #}
# TODO(yon-mg): need to handle grpc transcoding and parse url correctly
# current impl assumes basic case of grpc transcoding
url = 'https://{host}{{ method.http_opt['url'] }}'.format(
host=self._host,
{% for field in method.path_params %}
{{ field }}=request.{{ method.input.get_field(field).name }},
{% endfor %}
)
uri = transcoded_request['uri']
method = transcoded_request['method']

{# TODO(yon-mg): move all query param logic out of wrappers into here to handle
nested fields correctly (can't just use set of top level fields
#}
# TODO(yon-mg): handle nested fields correctly rather than using only top level fields
# not required for GCE
query_params = {}
{% for field in method.query_params | sort%}
{% if method.input.fields[field].proto3_optional %}
if {{ method.input.ident }}.{{ field }} in request:
query_params['{{ field|camel_case }}'] = request.{{ field }}
{% else %}
query_params['{{ field|camel_case }}'] = request.{{ field }}
{% endif %}
{% endfor %}
# Jsonify the query params
query_params = json.loads({{method.input.ident}}.to_json(
{{method.input.ident}}(transcoded_request['query_params']),
including_default_value_fields=False,
use_integers_for_enums=False
))

# Send the request
headers = dict(metadata)
headers['Content-Type'] = 'application/json'
response = self._session.{{ method.http_opt['verb'] }}(
url,
response=getattr(self._session, method)(
uri,
timeout=timeout,
headers=headers,
params=query_params,
{% if 'body' in method.http_opt %}
params=rest_helpers.flatten_query_params(query_params),
{% if body_spec %}
data=body,
{% endif %}
)
Expand All @@ -235,16 +241,50 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{% if not method.void %}

# Return the response
return {{ method.output.ident }}.from_json(
return {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)
{% endif %}
{% endif %}
{% else %}

def _{{method.name | snake_case}}(self,
request: {{method.input.ident}}, *,
metadata: Sequence[Tuple[str, str]]=(),
) -> {{method.output.ident}}:
r"""Placeholder: Unable to implement over REST
"""
{%- if not method.http_options %}

raise RuntimeError(
"Cannot define a method without a valid 'google.api.http' annotation.")
{%- elif method.lro %}

raise NotImplementedError(
"LRO over REST is not yet defined for python client.")
{%- elif method.server_streaming or method.client_streaming %}

raise NotImplementedError(
"Streaming over REST is not yet defined for python client")
{%- else %}

raise NotImplementedError()
{%- endif %}
{%- endif %}


{% endfor %}
{%- for method in service.methods.values() %}

@ property
def {{method.name | snake_case}}(self) -> Callable[
[{{method.input.ident}}],
{{method.output.ident}}]:
return self._{{method.name | snake_case}}
{%- endfor %}


__all__ = (
__all__=(
'{{ service.name }}RestTransport',
)
{% endblock %}
Loading

0 comments on commit ccdd17d

Please sign in to comment.