Skip to content

Commit

Permalink
fix: remove duplicate field entries (#786)
Browse files Browse the repository at this point in the history
Fix for #778
  • Loading branch information
software-dov authored Feb 24, 2021
1 parent 35338fc commit 9f4dfa4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
request = {{ method.input.ident }}(**request)
{% if method.flattened_fields -%}{# Cross-package req and flattened fields #}
elif not request:
request = {{ method.input.ident }}({% if method.input.ident.package != method.ident.package %}{% for f in method.flattened_fields.values() %}{{ f.name }}={{ f.name }}, {% endfor %}{% endif %})
request = {{ method.input.ident }}()
{% endif -%}{# Cross-package req and flattened fields #}
{%- else %}
# Minor optimization to avoid making a copy if the user passes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,30 @@ def test_{{ method.name|snake_case }}_from_dict():
test_{{ method.name|snake_case }}(request_type=dict)


{% if not method.client_streaming -%}
def test_{{ method.name|snake_case }}_empty_call():
# This test is a coverage failsafe to make sure that totally empty calls,
# i.e. request == None and no flattened fields passed, work.
client = {{ service.client_name }}(
credentials=credentials.AnonymousCredentials(),
transport='grpc',
)

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
'__call__') as call:
client.{{ method.name|snake_case }}()
call.assert_called()
_, args, _ = call.mock_calls[0]
{% if method.client_streaming %}
assert next(args[0]) == request
{% else %}
assert args[0] == {{ method.input.ident }}()
{% endif %}
{% endif -%}


@pytest.mark.asyncio
async def test_{{ method.name|snake_case }}_async(transport: str = 'grpc_asyncio', request_type={{ method.input.ident }}):
client = {{ service.async_client_name }}(
Expand Down Expand Up @@ -1276,7 +1300,7 @@ def test_{{ method.name|snake_case }}_pager():
for result in results:
assert isinstance(result, tuple)
assert tuple(type(t) for t in result) == (str, {{ method.paged_result_field.type.fields.get('value').ident }})

assert pager.get('a') is None
assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }})
{% else %}
Expand All @@ -1288,7 +1312,7 @@ def test_{{ method.name|snake_case }}_pager():
for page_, token in zip(pages, ['abc','def','ghi', '']):
assert page_.raw_page.next_page_token == token


{% endif %} {# paged methods #}
{% endfor -%} {#- method in methods for rest #}
def test_credentials_transport_error():
Expand Down Expand Up @@ -1500,7 +1524,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_client_cert_source_for_mtl
("grpc.max_receive_message_length", -1),
],
)

# Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
# is used.
with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
Expand Down

0 comments on commit 9f4dfa4

Please sign in to comment.