Skip to content

Commit 5585181

Browse files
authored
fix: don't use stale session in rest transport (#1291)
* fix: don't use stale session in rest transport * add test
1 parent 3433b62 commit 5585181

File tree

4 files changed

+46
-12
lines changed
  • packages/gapic-generator/gapic

4 files changed

+46
-12
lines changed

packages/gapic-generator/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ class {{service.name}}RestTransport({{service.name}}Transport):
129129

130130
It sends JSON representations of protocol buffers over HTTP/1.1
131131
"""
132-
_STUBS: Dict[str, {{service.name}}RestStub] = {}
133132

134133

135134
{# TODO(yon-mg): handle mtls stuff if that is relevant for rest transport #}
@@ -399,13 +398,9 @@ class {{service.name}}RestTransport({{service.name}}Transport):
399398
def {{method.transport_safe_name | snake_case}}(self) -> Callable[
400399
[{{method.input.ident}}],
401400
{{method.output.ident}}]:
402-
stub = self._STUBS.get("{{method.transport_safe_name | snake_case}}")
403-
if not stub:
404-
stub = self._STUBS["{{method.transport_safe_name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
405-
406401
# The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
407402
# In C++ this would require a dynamic_cast
408-
return stub # type: ignore
403+
return self._{{method.name}}(self._session, self._host, self._interceptor) # type: ignore
409404

410405
{% endfor %}
411406

packages/gapic-generator/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,6 +1941,28 @@ def test_{{ service.name|snake_case }}_host_with_port(transport_name):
19411941
)
19421942
{% endwith %}
19431943

1944+
{% if 'rest' in opts.transport %}
1945+
@pytest.mark.parametrize("transport_name", [
1946+
"rest",
1947+
])
1948+
def test_{{ service.name|snake_case }}_client_transport_session_collision(transport_name):
1949+
creds1 = ga_credentials.AnonymousCredentials()
1950+
creds2 = ga_credentials.AnonymousCredentials()
1951+
client1 = {{ service.client_name }}(
1952+
credentials=creds1,
1953+
transport=transport_name,
1954+
)
1955+
client2 = {{ service.client_name }}(
1956+
credentials=creds2,
1957+
transport=transport_name,
1958+
)
1959+
{% for method in service.methods.values() %}
1960+
session1 = client1.transport.{{ method.transport_safe_name|snake_case }}._session
1961+
session2 = client2.transport.{{ method.transport_safe_name|snake_case }}._session
1962+
assert session1 != session2
1963+
{% endfor %}
1964+
{% endif -%}
1965+
19441966
{% if 'grpc' in opts.transport %}
19451967
def test_{{ service.name|snake_case }}_grpc_transport_channel():
19461968
channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())

packages/gapic-generator/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ class {{service.name}}RestTransport({{service.name}}Transport):
129129

130130
It sends JSON representations of protocol buffers over HTTP/1.1
131131
"""
132-
_STUBS: Dict[str, {{service.name}}RestStub] = {}
133132

134133

135134
{# TODO(yon-mg): handle mtls stuff if that is relevant for rest transport #}
@@ -399,13 +398,9 @@ class {{service.name}}RestTransport({{service.name}}Transport):
399398
def {{method.transport_safe_name|snake_case}}(self) -> Callable[
400399
[{{method.input.ident}}],
401400
{{method.output.ident}}]:
402-
stub = self._STUBS.get("{{method.transport_safe_name|snake_case}}")
403-
if not stub:
404-
stub = self._STUBS["{{method.transport_safe_name|snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
405-
406401
# The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
407402
# In C++ this would require a dynamic_cast
408-
return stub # type: ignore
403+
return self._{{method.name}}(self._session, self._host, self._interceptor) # type: ignore
409404

410405
{% endfor %}
411406

packages/gapic-generator/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,28 @@ def test_{{ service.name|snake_case }}_host_with_port(transport_name):
10311031
)
10321032
{% endwith %}
10331033

1034+
{% if 'rest' in opts.transport %}
1035+
@pytest.mark.parametrize("transport_name", [
1036+
"rest",
1037+
])
1038+
def test_{{ service.name|snake_case }}_client_transport_session_collision(transport_name):
1039+
creds1 = ga_credentials.AnonymousCredentials()
1040+
creds2 = ga_credentials.AnonymousCredentials()
1041+
client1 = {{ service.client_name }}(
1042+
credentials=creds1,
1043+
transport=transport_name,
1044+
)
1045+
client2 = {{ service.client_name }}(
1046+
credentials=creds2,
1047+
transport=transport_name,
1048+
)
1049+
{% for method in service.methods.values() %}
1050+
session1 = client1.transport.{{ method.transport_safe_name|snake_case }}._session
1051+
session2 = client2.transport.{{ method.transport_safe_name|snake_case }}._session
1052+
assert session1 != session2
1053+
{% endfor %}
1054+
{% endif -%}
1055+
10341056
{% if 'grpc' in opts.transport %}
10351057
def test_{{ service.name|snake_case }}_grpc_transport_channel():
10361058
channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())

0 commit comments

Comments
 (0)