Skip to content

Commit 706f789

Browse files
authored
Adding a new transport class to handle Phoenix channels (#100)
1 parent 0acea14 commit 706f789

11 files changed

+768
-46
lines changed

gql/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from .client import Client
22
from .gql import gql
33
from .transport.aiohttp import AIOHTTPTransport
4+
from .transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport
45
from .transport.requests import RequestsHTTPTransport
56
from .transport.websockets import WebsocketsTransport
67

78
__all__ = [
89
"gql",
910
"AIOHTTPTransport",
1011
"Client",
12+
"PhoenixChannelWebsocketsTransport",
1113
"RequestsHTTPTransport",
1214
"WebsocketsTransport",
1315
]
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
import asyncio
2+
import json
3+
from typing import Dict, Optional, Tuple
4+
5+
from graphql import DocumentNode, ExecutionResult, print_ast
6+
from websockets.exceptions import ConnectionClosed
7+
8+
from .exceptions import (
9+
TransportProtocolError,
10+
TransportQueryError,
11+
TransportServerError,
12+
)
13+
from .websockets import WebsocketsTransport
14+
15+
16+
class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
17+
def __init__(
18+
self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs
19+
) -> None:
20+
self.channel_name = channel_name
21+
self.heartbeat_interval = heartbeat_interval
22+
self.subscription_ids_to_query_ids: Dict[str, int] = {}
23+
super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)
24+
"""Initialize the transport with the given request parameters.
25+
26+
:param channel_name Channel on the server this transport will join
27+
:param heartbeat_interval Interval in second between each heartbeat messages
28+
sent by the client
29+
"""
30+
31+
async def _send_init_message_and_wait_ack(self) -> None:
32+
"""Join the specified channel and wait for the connection ACK.
33+
34+
If the answer is not a connection_ack message, we will return an Exception.
35+
"""
36+
37+
query_id = self.next_query_id
38+
self.next_query_id += 1
39+
40+
init_message = json.dumps(
41+
{
42+
"topic": self.channel_name,
43+
"event": "phx_join",
44+
"payload": {},
45+
"ref": query_id,
46+
}
47+
)
48+
49+
await self._send(init_message)
50+
51+
# Wait for the connection_ack message or raise a TimeoutError
52+
init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout)
53+
54+
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
55+
56+
if answer_type != "reply":
57+
raise TransportProtocolError(
58+
"Websocket server did not return a connection ack"
59+
)
60+
61+
async def heartbeat_coro():
62+
while True:
63+
await asyncio.sleep(self.heartbeat_interval)
64+
try:
65+
query_id = self.next_query_id
66+
self.next_query_id += 1
67+
68+
await self._send(
69+
json.dumps(
70+
{
71+
"topic": "phoenix",
72+
"event": "heartbeat",
73+
"payload": {},
74+
"ref": query_id,
75+
}
76+
)
77+
)
78+
except ConnectionClosed: # pragma: no cover
79+
return
80+
81+
self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())
82+
83+
async def _send_stop_message(self, query_id: int) -> None:
84+
try:
85+
await self.listeners[query_id].put(("complete", None))
86+
except KeyError: # pragma: no cover
87+
pass
88+
89+
async def _send_connection_terminate_message(self) -> None:
90+
"""Send a phx_leave message to disconnect from the provided channel.
91+
"""
92+
93+
query_id = self.next_query_id
94+
self.next_query_id += 1
95+
96+
connection_terminate_message = json.dumps(
97+
{
98+
"topic": self.channel_name,
99+
"event": "phx_leave",
100+
"payload": {},
101+
"ref": query_id,
102+
}
103+
)
104+
105+
await self._send(connection_terminate_message)
106+
107+
async def _send_query(
108+
self,
109+
document: DocumentNode,
110+
variable_values: Optional[Dict[str, str]] = None,
111+
operation_name: Optional[str] = None,
112+
) -> int:
113+
"""Send a query to the provided websocket connection.
114+
115+
We use an incremented id to reference the query.
116+
117+
Returns the used id for this query.
118+
"""
119+
120+
query_id = self.next_query_id
121+
self.next_query_id += 1
122+
123+
query_str = json.dumps(
124+
{
125+
"topic": self.channel_name,
126+
"event": "doc",
127+
"payload": {
128+
"query": print_ast(document),
129+
"variables": variable_values or {},
130+
},
131+
"ref": query_id,
132+
}
133+
)
134+
135+
await self._send(query_str)
136+
137+
return query_id
138+
139+
def _parse_answer(
140+
self, answer: str
141+
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
142+
"""Parse the answer received from the server
143+
144+
Returns a list consisting of:
145+
- the answer_type (between:
146+
'heartbeat', 'data', 'reply', 'error', 'close')
147+
- the answer id (Integer) if received or None
148+
- an execution Result if the answer_type is 'data' or None
149+
"""
150+
151+
event: str = ""
152+
answer_id: Optional[int] = None
153+
answer_type: str = ""
154+
execution_result: Optional[ExecutionResult] = None
155+
156+
try:
157+
json_answer = json.loads(answer)
158+
159+
event = str(json_answer.get("event"))
160+
161+
if event == "subscription:data":
162+
payload = json_answer.get("payload")
163+
164+
if not isinstance(payload, dict):
165+
raise ValueError("payload is not a dict")
166+
167+
subscription_id = str(payload.get("subscriptionId"))
168+
try:
169+
answer_id = self.subscription_ids_to_query_ids[subscription_id]
170+
except KeyError:
171+
raise ValueError(
172+
f"subscription '{subscription_id}' has not been registerd"
173+
)
174+
175+
result = payload.get("result")
176+
177+
if not isinstance(result, dict):
178+
raise ValueError("result is not a dict")
179+
180+
answer_type = "data"
181+
182+
execution_result = ExecutionResult(
183+
errors=payload.get("errors"), data=result.get("data")
184+
)
185+
186+
elif event == "phx_reply":
187+
answer_id = int(json_answer.get("ref"))
188+
payload = json_answer.get("payload")
189+
190+
if not isinstance(payload, dict):
191+
raise ValueError("payload is not a dict")
192+
193+
status = str(payload.get("status"))
194+
195+
if status == "ok":
196+
197+
answer_type = "reply"
198+
response = payload.get("response")
199+
200+
if isinstance(response, dict) and "subscriptionId" in response:
201+
subscription_id = str(response.get("subscriptionId"))
202+
self.subscription_ids_to_query_ids[subscription_id] = answer_id
203+
204+
elif status == "error":
205+
response = payload.get("response")
206+
207+
if isinstance(response, dict):
208+
if "errors" in response:
209+
raise TransportQueryError(
210+
str(response.get("errors")), query_id=answer_id
211+
)
212+
elif "reason" in response:
213+
raise TransportQueryError(
214+
str(response.get("reason")), query_id=answer_id
215+
)
216+
raise ValueError("reply error")
217+
218+
elif status == "timeout":
219+
raise TransportQueryError("reply timeout", query_id=answer_id)
220+
221+
elif event == "phx_error":
222+
raise TransportServerError("Server error")
223+
elif event == "phx_close":
224+
answer_type = "close"
225+
else:
226+
raise ValueError
227+
228+
except ValueError as e:
229+
raise TransportProtocolError(
230+
"Server did not return a GraphQL result"
231+
) from e
232+
233+
return answer_type, answer_id, execution_result
234+
235+
async def _handle_answer(
236+
self,
237+
answer_type: str,
238+
answer_id: Optional[int],
239+
execution_result: Optional[ExecutionResult],
240+
) -> None:
241+
if answer_type == "close":
242+
await self.close()
243+
else:
244+
await super()._handle_answer(answer_type, answer_id, execution_result)
245+
246+
async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
247+
if self.heartbeat_task is not None:
248+
self.heartbeat_task.cancel()
249+
250+
await super()._close_coro(e, clean_close)

gql/transport/websockets.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -371,19 +371,25 @@ async def _receive_data_loop(self) -> None:
371371
await self._fail(e, clean_close=False)
372372
break
373373

374-
try:
375-
# Put the answer in the queue
376-
if answer_id is not None:
377-
await self.listeners[answer_id].put(
378-
(answer_type, execution_result)
379-
)
380-
except KeyError:
381-
# Do nothing if no one is listening to this query_id.
382-
pass
374+
await self._handle_answer(answer_type, answer_id, execution_result)
383375

384376
finally:
385377
log.debug("Exiting _receive_data_loop()")
386378

379+
async def _handle_answer(
380+
self,
381+
answer_type: str,
382+
answer_id: Optional[int],
383+
execution_result: Optional[ExecutionResult],
384+
) -> None:
385+
try:
386+
# Put the answer in the queue
387+
if answer_id is not None:
388+
await self.listeners[answer_id].put((answer_type, execution_result))
389+
except KeyError:
390+
# Do nothing if no one is listening to this query_id.
391+
pass
392+
387393
async def subscribe(
388394
self,
389395
document: DocumentNode,

tests/conftest.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ async def stop(self):
136136

137137
print("Server stopped\n\n\n")
138138

139+
140+
class WebSocketServerHelper:
139141
@staticmethod
140142
async def send_complete(ws, query_id):
141143
await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}')
@@ -165,6 +167,26 @@ async def wait_connection_terminate(ws):
165167
assert json_result["type"] == "connection_terminate"
166168

167169

170+
class PhoenixChannelServerHelper:
171+
@staticmethod
172+
async def send_close(ws):
173+
await ws.send('{"event":"phx_close"}')
174+
175+
@staticmethod
176+
async def send_connection_ack(ws):
177+
178+
# Line return for easy debugging
179+
print("")
180+
181+
# Wait for init
182+
result = await ws.recv()
183+
json_result = json.loads(result)
184+
assert json_result["event"] == "phx_join"
185+
186+
# Send ack
187+
await ws.send('{"event":"phx_reply", "payload": {"status": "ok"}, "ref": 1}')
188+
189+
168190
def get_server_handler(request):
169191
"""Get the server handler.
170192
@@ -181,7 +203,7 @@ def get_server_handler(request):
181203
async def default_server_handler(ws, path):
182204

183205
try:
184-
await WebSocketServer.send_connection_ack(ws)
206+
await WebSocketServerHelper.send_connection_ack(ws)
185207
query_id = 1
186208

187209
for answer in answers:
@@ -195,10 +217,10 @@ async def default_server_handler(ws, path):
195217
formatted_answer = answer
196218

197219
await ws.send(formatted_answer)
198-
await WebSocketServer.send_complete(ws, query_id)
220+
await WebSocketServerHelper.send_complete(ws, query_id)
199221
query_id += 1
200222

201-
await WebSocketServer.wait_connection_terminate(ws)
223+
await WebSocketServerHelper.wait_connection_terminate(ws)
202224
await ws.wait_closed()
203225
except ConnectionClosed:
204226
pass

tests/test_async_client_validation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from gql import Client, gql
99
from gql.transport.websockets import WebsocketsTransport
1010

11-
from .conftest import MS, WebSocketServer
11+
from .conftest import MS, WebSocketServerHelper
1212
from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef
1313

1414
starwars_expected_one = {
@@ -25,7 +25,7 @@
2525

2626

2727
async def server_starwars(ws, path):
28-
await WebSocketServer.send_connection_ack(ws)
28+
await WebSocketServerHelper.send_connection_ack(ws)
2929

3030
try:
3131
await ws.recv()
@@ -42,8 +42,8 @@ async def server_starwars(ws, path):
4242
await ws.send(data)
4343
await asyncio.sleep(2 * MS)
4444

45-
await WebSocketServer.send_complete(ws, 1)
46-
await WebSocketServer.wait_connection_terminate(ws)
45+
await WebSocketServerHelper.send_complete(ws, 1)
46+
await WebSocketServerHelper.wait_connection_terminate(ws)
4747

4848
except websockets.exceptions.ConnectionClosedOK:
4949
pass

0 commit comments

Comments
 (0)