Skip to content

Commit 62a7ca2

Browse files
authored
Merge branch 'master' into fix_passing_positional_arguments_in_transports
2 parents b29561e + 5127e8c commit 62a7ca2

File tree

6 files changed

+112
-47
lines changed

6 files changed

+112
-47
lines changed

docs/transports/websockets.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@ The websockets transport supports both:
88
- the `Apollo websockets transport protocol`_.
99
- the `GraphQL-ws websockets transport protocol`_
1010

11-
It will detect the backend supported protocol from the response http headers returned.
11+
It will propose both subprotocols to the backend and detect the supported protocol
12+
from the response http headers returned by the backend.
13+
14+
.. note::
15+
For some backends (graphql-ws before `version 5.6.1`_ without backwards compatibility), it may be necessary to specify
16+
only one subprotocol to the backend. It can be done by using
17+
:code:`subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL]`
18+
or :code:`subprotocols=[WebsocketsTransport.APOLLO_SUBPROTOCOL]` in the transport arguments.
1219

1320
This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection.
1421

@@ -118,5 +125,6 @@ Here is an example with a ping sent every 60 seconds, expecting a pong within 10
118125
pong_timeout=10,
119126
)
120127

128+
.. _version 5.6.1: https://github.com/enisdenjo/graphql-ws/releases/tag/v5.6.1
121129
.. _Apollo websockets transport protocol: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
122130
.. _GraphQL-ws websockets transport protocol: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md

gql/client.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,15 @@ async def __aenter__(self):
583583
self.session = AsyncClientSession(client=self)
584584

585585
# Get schema from transport if needed
586-
if self.fetch_schema_from_transport and not self.schema:
587-
await self.session.fetch_schema()
586+
try:
587+
if self.fetch_schema_from_transport and not self.schema:
588+
await self.session.fetch_schema()
589+
except Exception:
590+
# we don't know what type of exception is thrown here because it
591+
# depends on the underlying transport; we just make sure that the
592+
# transport is closed and re-raise the exception
593+
await self.transport.close()
594+
raise
588595

589596
return self.session
590597

@@ -605,8 +612,15 @@ def __enter__(self):
605612
self.session = SyncClientSession(client=self)
606613

607614
# Get schema from transport if needed
608-
if self.fetch_schema_from_transport and not self.schema:
609-
self.session.fetch_schema()
615+
try:
616+
if self.fetch_schema_from_transport and not self.schema:
617+
self.session.fetch_schema()
618+
except Exception:
619+
# we don't know what type of exception is thrown here because it
620+
# depends on the underlying transport; we just make sure that the
621+
# transport is closed and re-raise the exception
622+
self.transport.close()
623+
raise
610624

611625
return self.session
612626

gql/transport/websockets.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from contextlib import suppress
55
from ssl import SSLContext
6-
from typing import Any, Dict, Optional, Tuple, Union, cast
6+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
77

88
from graphql import DocumentNode, ExecutionResult, print_ast
99
from websockets.datastructures import HeadersLike
@@ -46,6 +46,7 @@ def __init__(
4646
pong_timeout: Optional[Union[int, float]] = None,
4747
answer_pings: bool = True,
4848
connect_args: Dict[str, Any] = {},
49+
subprotocols: Optional[List[Subprotocol]] = None,
4950
) -> None:
5051
"""Initialize the transport with the given parameters.
5152
@@ -71,6 +72,9 @@ def __init__(
7172
(for the graphql-ws protocol).
7273
By default: True
7374
:param connect_args: Other parameters forwarded to websockets.connect
75+
:param subprotocols: list of subprotocols sent to the
76+
backend in the 'subprotocols' http header.
77+
By default: both apollo and graphql-ws subprotocols.
7478
"""
7579

7680
super().__init__(
@@ -105,10 +109,13 @@ def __init__(
105109
"""pong_received is an asyncio Event which will fire each time
106110
a pong is received with the graphql-ws protocol"""
107111

108-
self.supported_subprotocols = [
109-
self.APOLLO_SUBPROTOCOL,
110-
self.GRAPHQLWS_SUBPROTOCOL,
111-
]
112+
if subprotocols is None:
113+
self.supported_subprotocols = [
114+
self.APOLLO_SUBPROTOCOL,
115+
self.GRAPHQLWS_SUBPROTOCOL,
116+
]
117+
else:
118+
self.supported_subprotocols = subprotocols
112119

113120
async def _wait_ack(self) -> None:
114121
"""Wait for the connection_ack message. Keep alive messages are ignored"""
@@ -272,6 +279,7 @@ def _parse_answer_graphqlws(
272279
- instead of a unidirectional keep-alive (ka) message from server to client,
273280
there is now the possibility to send bidirectional ping/pong messages
274281
- connection_ack has an optional payload
282+
- the 'error' answer type returns a list of errors instead of a single error
275283
"""
276284

277285
answer_type: str = ""
@@ -288,11 +296,11 @@ def _parse_answer_graphqlws(
288296

289297
payload = json_answer.get("payload")
290298

291-
if not isinstance(payload, dict):
292-
raise ValueError("payload is not a dict")
293-
294299
if answer_type == "next":
295300

301+
if not isinstance(payload, dict):
302+
raise ValueError("payload is not a dict")
303+
296304
if "errors" not in payload and "data" not in payload:
297305
raise ValueError(
298306
"payload does not contain 'data' or 'errors' fields"
@@ -309,8 +317,11 @@ def _parse_answer_graphqlws(
309317

310318
elif answer_type == "error":
311319

320+
if not isinstance(payload, list):
321+
raise ValueError("payload is not a list")
322+
312323
raise TransportQueryError(
313-
str(payload), query_id=answer_id, errors=[payload]
324+
str(payload[0]), query_id=answer_id, errors=payload
314325
)
315326

316327
elif answer_type in ["ping", "pong", "connection_ack"]:

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,9 @@ async def client_and_graphqlws_server(graphqlws_server):
465465
# Generate transport to connect to the server fixture
466466
path = "/graphql"
467467
url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}"
468-
sample_transport = WebsocketsTransport(url=url)
468+
sample_transport = WebsocketsTransport(
469+
url=url, subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL],
470+
)
469471

470472
async with Client(transport=sample_transport) as session:
471473

tests/test_client.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,50 @@ def test_gql():
200200
client = Client(schema=schema)
201201
result = client.execute(query)
202202
assert result["user"] is None
203+
204+
205+
@pytest.mark.requests
206+
def test_sync_transport_close_on_schema_retrieval_failure():
207+
"""
208+
Ensure that the transport session is closed if an error occurs when
209+
entering the context manager (e.g., because schema retrieval fails)
210+
"""
211+
212+
from gql.transport.requests import RequestsHTTPTransport
213+
214+
transport = RequestsHTTPTransport(url="http://localhost/")
215+
client = Client(transport=transport, fetch_schema_from_transport=True)
216+
217+
try:
218+
with client:
219+
pass
220+
except Exception:
221+
# we don't care what exception is thrown, we just want to check if the
222+
# transport is closed afterwards
223+
pass
224+
225+
assert client.transport.session is None
226+
227+
228+
@pytest.mark.aiohttp
229+
@pytest.mark.asyncio
230+
async def test_async_transport_close_on_schema_retrieval_failure():
231+
"""
232+
Ensure that the transport session is closed if an error occurs when
233+
entering the context manager (e.g., because schema retrieval fails)
234+
"""
235+
236+
from gql.transport.aiohttp import AIOHTTPTransport
237+
238+
transport = AIOHTTPTransport(url="http://localhost/")
239+
client = Client(transport=transport, fetch_schema_from_transport=True)
240+
241+
try:
242+
async with client:
243+
pass
244+
except Exception:
245+
# we don't care what exception is thrown, we just want to check if the
246+
# transport is closed afterwards
247+
pass
248+
249+
assert client.transport.session is None

tests/test_graphqlws_exceptions.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import asyncio
2-
import json
3-
import types
42
from typing import List
53

64
import pytest
@@ -125,49 +123,29 @@ async def test_graphqlws_server_does_not_send_ack(
125123
pass
126124

127125

128-
invalid_payload_server_answer = (
129-
'{"type":"error","id":"1","payload":{"message":"Must provide document"}}'
126+
invalid_query_server_answer = (
127+
'{"id":"1","type":"error","payload":[{"message":"Cannot query field '
128+
'\\"helo\\" on type \\"Query\\". Did you mean \\"hello\\"?",'
129+
'"locations":[{"line":2,"column":3}]}]}'
130130
)
131131

132132

133-
async def server_invalid_payload(ws, path):
133+
async def server_invalid_query(ws, path):
134134
await WebSocketServerHelper.send_connection_ack(ws)
135135
result = await ws.recv()
136136
print(f"Server received: {result}")
137-
await ws.send(invalid_payload_server_answer)
137+
await ws.send(invalid_query_server_answer)
138138
await WebSocketServerHelper.wait_connection_terminate(ws)
139139
await ws.wait_closed()
140140

141141

142142
@pytest.mark.asyncio
143-
@pytest.mark.parametrize("graphqlws_server", [server_invalid_payload], indirect=True)
144-
@pytest.mark.parametrize("query_str", [invalid_query_str])
145-
async def test_graphqlws_sending_invalid_payload(
146-
event_loop, client_and_graphqlws_server, query_str
147-
):
143+
@pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True)
144+
async def test_graphqlws_sending_invalid_query(event_loop, client_and_graphqlws_server):
148145

149146
session, server = client_and_graphqlws_server
150147

151-
# Monkey patching the _send_query method to send an invalid payload
152-
153-
async def monkey_patch_send_query(
154-
self, document, variable_values=None, operation_name=None,
155-
) -> int:
156-
query_id = self.next_query_id
157-
self.next_query_id += 1
158-
159-
query_str = json.dumps(
160-
{"id": str(query_id), "type": "subscribe", "payload": "BLAHBLAH"}
161-
)
162-
163-
await self._send(query_str)
164-
return query_id
165-
166-
session.transport._send_query = types.MethodType(
167-
monkey_patch_send_query, session.transport
168-
)
169-
170-
query = gql(query_str)
148+
query = gql("{helo}")
171149

172150
with pytest.raises(TransportQueryError) as exc_info:
173151
await session.execute(query)
@@ -178,7 +156,10 @@ async def monkey_patch_send_query(
178156

179157
error = exception.errors[0]
180158

181-
assert error["message"] == "Must provide document"
159+
assert (
160+
error["message"]
161+
== 'Cannot query field "helo" on type "Query". Did you mean "hello"?'
162+
)
182163

183164

184165
not_json_answer = ["BLAHBLAH"]
@@ -188,6 +169,7 @@ async def monkey_patch_send_query(
188169
missing_id_answer_3 = ['{"type": "complete"}']
189170
data_without_payload = ['{"type": "next", "id":"1"}']
190171
error_without_payload = ['{"type": "error", "id":"1"}']
172+
error_with_payload_not_a_list = ['{"type": "error", "id":"1", "payload": "NOT A LIST"}']
191173
payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}']
192174
empty_payload = ['{"type": "next", "id":"1", "payload": {}}']
193175
sending_bytes = [b"\x01\x02\x03"]
@@ -205,6 +187,7 @@ async def monkey_patch_send_query(
205187
data_without_payload,
206188
error_without_payload,
207189
payload_is_not_a_dict,
190+
error_with_payload_not_a_list,
208191
empty_payload,
209192
sending_bytes,
210193
],

0 commit comments

Comments
 (0)