Skip to content

Commit

Permalink
feat: bypass request copying in method calls (#557)
Browse files Browse the repository at this point in the history
If a proto-plus message is passed in, do not copy it.
  • Loading branch information
software-dov authored Jul 28, 2020
1 parent 1d08e60 commit 3a23143
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -287,24 +287,28 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
request = {{ method.input.ident }}()
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
request = {{ method.input.ident }}(request)
# Minor optimization to avoid making a copy if the user passes
# in a {{ method.input.ident }}.
# There's no risk of modifying the input as we've already verified
# there are no flattened fields.
if not isinstance(request, {{ method.input.ident }}):
request = {{ method.input.ident }}(request)
{% endif %} {# different request package #}

{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,15 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():


{% for method in service.methods.values() -%}
def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = {{ method.input.ident }}()
request = request_type()
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
{% if method.client_streaming %}
assert next(args[0]) == request
{% else %}
assert args[0] == request
assert args[0] == {{ method.input.ident }}()
{% endif %}

# Establish that the response is the type that we expect.
Expand All @@ -275,6 +275,11 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
{% endfor %}
{% endif %}


def test_{{ method.name|snake_case }}_from_dict():
test_{{ method.name|snake_case }}(request_type=dict)


{% if method.field_headers and not method.client_streaming %}
def test_{{ method.name|snake_case }}_field_headers():
client = {{ service.client_name }}(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
{% if method.flattened_fields -%}
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
if request is not None and any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}]):
has_flattened_params = any([{{ method.flattened_fields.values()|join(', ', attribute='name') }}])
if request is not None and has_flattened_params:
raise ValueError('If the `request` argument is set, then none of '
'the individual field arguments should be set.')

Expand All @@ -297,24 +298,29 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
request = {{ method.input.ident }}()
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
request = {{ method.input.ident }}(request)
# Minor optimization to avoid making a copy if the user passes
# in a {{ method.input.ident }}.
# There's no risk of modifying the input as we've already verified
# there are no flattened fields.
if not isinstance(request, {{ method.input.ident }}):
request = {{ method.input.ident }}(request)
{% endif %} {# different request package #}

{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}
{#- Vanilla python protobuf wrapper types cannot _set_ repeated fields #}
{% if method.flattened_fields -%}
# If we have keyword arguments corresponding to fields on the
# request, apply these.
{% endif -%}
{%- for key, field in method.flattened_fields.items() if not(field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }} is not None:
request.{{ key }} = {{ field.name }}
{%- endfor %}
{# They can be _extended_, however -#}
{%- for key, field in method.flattened_fields.items() if (field.repeated and method.input.ident.package != method.ident.package) %}
if {{ field.name }}:
request.{{ key }}.extend({{ field.name }})
{%- endfor %}
{%- endif %}

# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():


{% for method in service.methods.values() -%}
def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
def test_{{ method.name|snake_case }}(transport: str = 'grpc', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = {{ method.input.ident }}()
request = request_type()
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -348,7 +348,7 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
{% if method.client_streaming %}
assert next(args[0]) == request
{% else %}
assert args[0] == request
assert args[0] == {{ method.input.ident }}()
{% endif %}

# Establish that the response is the type that we expect.
Expand All @@ -374,6 +374,10 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
{% endif %}


def test_{{ method.name|snake_case }}_from_dict():
test_{{ method.name|snake_case }}(request_type=dict)


@pytest.mark.asyncio
async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio'):
client = {{ service.async_client_name }}(
Expand Down

0 comments on commit 3a23143

Please sign in to comment.