Skip to content

Commit 0cf8aaa

Browse files
committed
Handle mutable default arguments cleanly
When generating code, ensure that default list/dict arguments are initialised in local scope if unspecified or `None`.
1 parent 0321160 commit 0cf8aaa

File tree

6 files changed

+69
-6
lines changed

6 files changed

+69
-6
lines changed

betterproto/plugin.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,26 @@ def generate_code(request, response):
314314
output["typing_imports"].add("Optional")
315315
break
316316

317+
# This section ensures that method arguments having a default
318+
# value that is initialised as a List/Dict (mutable) is replaced
319+
# with None and initialisation is deferred to the beginning of the
320+
# method definition. This is done so to avoid any side-effects.
321+
# Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
322+
mutable_default_args = []
323+
if (
324+
not method.client_streaming
325+
and input_message
326+
and input_message.get("properties")
327+
):
328+
for f in input_message.get("properties"):
329+
if f["zero"] != "None" and (
330+
f["type"].startswith("List[")
331+
or f["type"].startswith("Dict[")
332+
):
333+
output["typing_imports"].add("Optional")
334+
mutable_default_args.append((f["py_name"], f["zero"]))
335+
f["zero"] = "None"
336+
317337
data["methods"].append(
318338
{
319339
"name": method.name,
@@ -332,6 +352,7 @@ def generate_code(request, response):
332352
),
333353
"client_streaming": method.client_streaming,
334354
"server_streaming": method.server_streaming,
355+
"mutable_default_args": mutable_default_args,
335356
}
336357
)
337358

betterproto/templates/template.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
8080
{{ method.comment }}
8181

8282
{% endif %}
83+
{%- for py_name, zero in method.mutable_default_args %}
84+
{{ py_name }} = {{ py_name }} or {{ zero }}
85+
{% endfor %}
86+
8387
{% if not method.client_streaming %}
8488
request = {{ method.input }}()
8589
{% for field in method.input_message.properties %}

betterproto/tests/grpc/test_grpclib_client.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import asyncio
2+
import sys
3+
4+
import grpclib
5+
import grpclib.metadata
6+
import pytest
7+
from grpclib.testing import ChannelFor
8+
9+
from betterproto.grpc.util.async_channel import AsyncChannel
210
from betterproto.tests.output_betterproto.service.service import (
3-
DoThingResponse,
411
DoThingRequest,
12+
DoThingResponse,
513
GetThingRequest,
614
TestStub as ThingServiceClient,
715
)
8-
import grpclib
9-
from grpclib.testing import ChannelFor
10-
import pytest
11-
from betterproto.grpc.util.async_channel import AsyncChannel
1216
from .thing_service import ThingService
1317

1418

@@ -35,6 +39,20 @@ async def test_simple_service_call():
3539
await _test_client(ThingServiceClient(channel))
3640

3741

42+
@pytest.mark.asyncio
43+
@pytest.mark.skipif(
44+
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
45+
)
46+
async def test_service_call_mutable_defaults(mocker):
47+
async with ChannelFor([ThingService()]) as channel:
48+
client = ThingServiceClient(channel)
49+
spy = mocker.spy(client, "_unary_unary")
50+
await _test_client(client)
51+
comments = spy.call_args_list[-1].args[1].comments
52+
await _test_client(client)
53+
assert spy.call_args_list[-1].args[1].comments is not comments
54+
55+
3856
@pytest.mark.asyncio
3957
async def test_service_call_with_upfront_request_params():
4058
# Setting deadline

betterproto/tests/inputs/service/service.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package service;
44

55
message DoThingRequest {
66
string name = 1;
7+
repeated string comments = 2;
78
}
89

910
message DoThingResponse {

poetry.lock

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ protobuf = "^3.12.2"
2929
pytest = "^5.4.2"
3030
pytest-asyncio = "^0.12.0"
3131
pytest-cov = "^2.9.0"
32+
pytest-mock = "^3.1.1"
3233
tox = "^3.15.1"
3334

3435
[tool.poetry.scripts]

0 commit comments

Comments
 (0)