Skip to content

Commit

Permalink
feat: add interceptor-like functionality to REST transport (#1142)
Browse files Browse the repository at this point in the history
Interceptors are a gRPC feature that wraps rpcs in
continuation-passing-style pre and post method custom functions.
These can be used e.g. for logging, local caching, and tweaking
metadata.

This PR adds interceptor like functionality to the REST transport in
generated GAPICs.

The REST transport interceptors differ in a few ways:
1) They are not continuations. For each method there is a slot for a
"pre"function, and for each method with a non-empty return there is a
slot for a "post" function.
2) There is always an interceptor for each method. The default simply
does nothing.
3) Existing gRPC interceptors and the new REST interceptors are not
composable or interoperable.
  • Loading branch information
software-dov authored Jan 21, 2022
1 parent feb7b4f commit fe57eb2
Show file tree
Hide file tree
Showing 9 changed files with 255 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from .grpc import {{ service.name }}GrpcTransport
{% endif %}
{% if 'rest' in opts.transport %}
from .rest import {{ service.name }}RestTransport
from .rest import {{ service.name }}RestInterceptor
{% endif %}

# Compile a registry of transports.
Expand All @@ -29,6 +30,7 @@ __all__ = (
{% endif %}
{% if 'rest' in opts.transport %}
'{{ service.name }}RestTransport',
'{{ service.name }}RestInterceptor',
{% endif %}
)
{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=requests_version,
)


class {{ service.name }}RestInterceptor:
"""Interceptor for {{ service.name }}.

Interceptors are used to manipulate requests, request metadata, and responses
in arbitrary ways.
Example use cases include:
* Logging
* Verifying requests according to service or custom semantics
* Stripping extraneous information from responses

These use cases and more can be enabled by injecting an
instance of a custom subclass when constructing the {{ service.name }}RestTransport.

.. code-block:
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(request, metadata):
logging.log(f"Received request: {request}")
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(response):
logging.log(f"Received response: {response}")
{% endif %}

{% endfor %}
transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor())
client = {{ service.client_name }}(transport=transport)


"""
{% for method in service.methods.values()|sort(attribute="name") if not(method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
"""Pre-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the request or metadata
before they are sent to the {{ service.name }} server.
"""
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the response
after it is returned by the {{ service.name }} server but before
it is returned to user code.
"""
return response
{% endif %}

{% endfor %}


@dataclasses.dataclass
class {{service.name}}RestStub:
_session: AuthorizedSession
_host: str
_interceptor: {{ service.name }}RestInterceptor


class {{service.name}}RestTransport({{service.name}}Transport):
"""REST backend transport for {{ service.name }}.
Expand Down Expand Up @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)
self._interceptor = interceptor or {{ service.name }}RestInterceptor()
self._prep_wrapped_messages(client_info)

{% if service.has_lro %}
Expand Down Expand Up @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
},
{% endfor %}{# rule in method.http_options #}
]

request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
Expand Down Expand Up @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not method.void %}
# Return the response
{% if method.lro %}
return_op = operations_pb2.Operation()
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
return return_op
resp = operations_pb2.Operation()
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
{% else %}
return {{method.output.ident}}.from_json(
resp = {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)

{% endif %}{# method.lro #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp
{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% if not method.http_options %}
Expand All @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host)
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

return stub

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ from google.api_core import grpc_helpers
from google.api_core import path_template
{% if service.has_lro %}
from google.api_core import future
from google.api_core import operation
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% if "rest" in opts.transport %}
Expand Down Expand Up @@ -1113,6 +1114,55 @@ def test_{{ method_name }}_rest_unset_required_fields():

{% endif %}{# required_fields #}

{% if not (method.server_streaming or method.client_streaming) %}
@pytest.mark.parametrize("null_interceptor", [True, False])
def test_{{ method_name }}_rest_interceptors(null_interceptor):
transport = transports.{{ service.name }}RestTransport(
credentials=ga_credentials.AnonymousCredentials(),
interceptor=None if null_interceptor else transports.{{ service.name}}RestInterceptor(),
)
client = {{ service.client_name }}(transport=transport)
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
{% if method.lro %}
mock.patch.object(operation.Operation, "_set_result_from_operation"), \
{% endif %}
{% if not method.void %}
mock.patch.object(transports.{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \
{% endif %}
mock.patch.object(transports.{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre:
pre.assert_not_called()
{% if not method.void %}
post.assert_not_called()
{% endif %}

transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},}

req.return_value = Response()
req.return_value.status_code = 200
req.return_value.request = PreparedRequest()
{% if not method.void %}
req.return_value._content = {% if method.output.ident.package == method.ident.package %}{{ method.output.ident }}.to_json({{ method.output.ident }}()){% else %}json_format.MessageToJson({{ method.output.ident }}()){% endif %}
{% endif %}

request = {{ method.input.ident }}()
metadata =[
("key", "val"),
("cephalopod", "squid"),
]
pre.return_value = request, metadata
{% if not method.void %}
post.return_value = {{ method.output.ident }}
{% endif %}

client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
{% if not method.void %}
post.assert_called_once()
{% endif %}
{% endif %}{# streaming #}


def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from .grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport
{% endif %}
{% if 'rest' in opts.transport %}
from .rest import {{ service.name }}RestTransport
from .rest import {{ service.name }}RestInterceptor
{% endif %}


Expand All @@ -34,6 +35,7 @@ __all__ = (
{% endif %}
{% if 'rest' in opts.transport %}
'{{ service.name }}RestTransport',
'{{ service.name }}RestInterceptor',
{% endif %}
)
{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
rest_version=requests_version,
)


class {{ service.name }}RestInterceptor:
"""Interceptor for {{ service.name }}.

Interceptors are used to manipulate requests, request metadata, and responses
in arbitrary ways.
Example use cases include:
* Logging
* Verifying requests according to service or custom semantics
* Stripping extraneous information from responses

These use cases and more can be enabled by injecting an
instance of a custom subclass when constructing the {{ service.name }}RestTransport.

.. code-block:
class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(request, metadata):
logging.log(f"Received request: {request}")
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(response):
logging.log(f"Received response: {response}")
{% endif %}

{% endfor %}
transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor())
client = {{ service.client_name }}(transport=transport)


"""
{% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %}
def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
"""Pre-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the request or metadata
before they are sent to the {{ service.name }} server.
"""
return request, metadata

{% if not method.void %}
def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}:
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to manipulate the response
after it is returned by the {{ service.name }} server but before
it is returned to user code.
"""
return response
{% endif %}

{% endfor %}


@dataclasses.dataclass
class {{service.name}}RestStub:
_session: AuthorizedSession
_host: str
_interceptor: {{ service.name }}RestInterceptor


class {{service.name}}RestTransport({{service.name}}Transport):
"""REST backend transport for {{ service.name }}.
Expand Down Expand Up @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool]=False,
url_scheme: str='https',
interceptor: Optional[{{ service.name }}RestInterceptor] = None,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)
self._interceptor = interceptor or {{ service.name }}RestInterceptor()
self._prep_wrapped_messages(client_info)

{% if service.has_lro %}
Expand Down Expand Up @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
},
{% endfor %}{# rule in method.http_options #}
]

request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
Expand Down Expand Up @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% if not method.void %}
# Return the response
{% if method.lro %}
return_op = operations_pb2.Operation()
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
return return_op
resp = operations_pb2.Operation()
json_format.Parse(response.content, resp, ignore_unknown_fields=True)
{% else %}
return {{method.output.ident}}.from_json(
resp = {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)

{% endif %}{# method.lro #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp
{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% if not method.http_options %}
Expand All @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{{method.output.ident}}]:
stub = self._STUBS.get("{{method.name | snake_case}}")
if not stub:
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host)
stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)

return stub

Expand Down
Loading

0 comments on commit fe57eb2

Please sign in to comment.