Skip to content

Fix errors raising TransportProtocolError with the graphql-ws protocol #299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions gql/transport/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def _parse_answer_graphqlws(
- instead of a unidirectional keep-alive (ka) message from server to client,
there is now the possibility to send bidirectional ping/pong messages
- connection_ack has an optional payload
- the 'error' answer type returns a list of errors instead of a single error
"""

answer_type: str = ""
Expand All @@ -288,11 +289,11 @@ def _parse_answer_graphqlws(

payload = json_answer.get("payload")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

if answer_type == "next":

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

if "errors" not in payload and "data" not in payload:
raise ValueError(
"payload does not contain 'data' or 'errors' fields"
Expand All @@ -309,8 +310,11 @@ def _parse_answer_graphqlws(

elif answer_type == "error":

if not isinstance(payload, list):
raise ValueError("payload is not a list")

raise TransportQueryError(
str(payload), query_id=answer_id, errors=[payload]
str(payload[0]), query_id=answer_id, errors=payload
)

elif answer_type in ["ping", "pong", "connection_ack"]:
Expand Down
47 changes: 15 additions & 32 deletions tests/test_graphqlws_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import asyncio
import json
import types
from typing import List

import pytest
Expand Down Expand Up @@ -125,49 +123,29 @@ async def test_graphqlws_server_does_not_send_ack(
pass


invalid_payload_server_answer = (
'{"type":"error","id":"1","payload":{"message":"Must provide document"}}'
invalid_query_server_answer = (
'{"id":"1","type":"error","payload":[{"message":"Cannot query field '
'\\"helo\\" on type \\"Query\\". Did you mean \\"hello\\"?",'
'"locations":[{"line":2,"column":3}]}]}'
)


async def server_invalid_payload(ws, path):
async def server_invalid_query(ws, path):
await WebSocketServerHelper.send_connection_ack(ws)
result = await ws.recv()
print(f"Server received: {result}")
await ws.send(invalid_payload_server_answer)
await ws.send(invalid_query_server_answer)
await WebSocketServerHelper.wait_connection_terminate(ws)
await ws.wait_closed()


@pytest.mark.asyncio
@pytest.mark.parametrize("graphqlws_server", [server_invalid_payload], indirect=True)
@pytest.mark.parametrize("query_str", [invalid_query_str])
async def test_graphqlws_sending_invalid_payload(
event_loop, client_and_graphqlws_server, query_str
):
@pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True)
async def test_graphqlws_sending_invalid_query(event_loop, client_and_graphqlws_server):

session, server = client_and_graphqlws_server

# Monkey patching the _send_query method to send an invalid payload

async def monkey_patch_send_query(
self, document, variable_values=None, operation_name=None,
) -> int:
query_id = self.next_query_id
self.next_query_id += 1

query_str = json.dumps(
{"id": str(query_id), "type": "subscribe", "payload": "BLAHBLAH"}
)

await self._send(query_str)
return query_id

session.transport._send_query = types.MethodType(
monkey_patch_send_query, session.transport
)

query = gql(query_str)
query = gql("{helo}")

with pytest.raises(TransportQueryError) as exc_info:
await session.execute(query)
Expand All @@ -178,7 +156,10 @@ async def monkey_patch_send_query(

error = exception.errors[0]

assert error["message"] == "Must provide document"
assert (
error["message"]
== 'Cannot query field "helo" on type "Query". Did you mean "hello"?'
)


not_json_answer = ["BLAHBLAH"]
Expand All @@ -188,6 +169,7 @@ async def monkey_patch_send_query(
missing_id_answer_3 = ['{"type": "complete"}']
data_without_payload = ['{"type": "next", "id":"1"}']
error_without_payload = ['{"type": "error", "id":"1"}']
error_with_payload_not_a_list = ['{"type": "error", "id":"1", "payload": "NOT A LIST"}']
payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}']
empty_payload = ['{"type": "next", "id":"1", "payload": {}}']
sending_bytes = [b"\x01\x02\x03"]
Expand All @@ -205,6 +187,7 @@ async def monkey_patch_send_query(
data_without_payload,
error_without_payload,
payload_is_not_a_dict,
error_with_payload_not_a_list,
empty_payload,
sending_bytes,
],
Expand Down