Skip to content
This repository was archived by the owner on May 21, 2019. It is now read-only.

Commit 7c38e82

Browse files
committed
Add websocket consumer and subscriptions
1 parent cc7d61e commit 7c38e82

File tree

2 files changed

+171
-26
lines changed

2 files changed

+171
-26
lines changed

mysite/graphql.py

Lines changed: 126 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,33 @@
1+
import asyncio
12
import json
2-
from typing import Any, Callable, List, Optional, Union
3+
from functools import partial
4+
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union, cast
35

46
from ariadne.constants import (
57
CONTENT_TYPE_JSON,
68
CONTENT_TYPE_TEXT_HTML,
79
CONTENT_TYPE_TEXT_PLAIN,
810
DATA_TYPE_JSON,
9-
HTTP_STATUS_200_OK,
10-
HTTP_STATUS_400_BAD_REQUEST,
1111
PLAYGROUND_HTML,
1212
)
1313
from ariadne.exceptions import HttpBadRequestError, HttpError, HttpMethodNotAllowedError
14-
from ariadne.executable_schema import make_executable_schema
1514
from ariadne.types import Bindable
1615
from channels.generic.http import AsyncHttpConsumer
17-
from graphql import GraphQLError, format_error, graphql
16+
from channels.generic.websocket import AsyncJsonWebsocketConsumer
17+
from graphql import GraphQLError, GraphQLSchema, format_error, graphql, parse, subscribe
1818
from graphql.execution import ExecutionResult
1919
import traceback
2020

2121

22-
class GraphQLConsumer(AsyncHttpConsumer):
23-
def __init__(
24-
self,
25-
type_defs: Union[str, List[str]],
26-
resolvers: Union[Bindable, List[Bindable], None] = None,
27-
*args,
28-
**kwargs
29-
):
30-
self.schema = make_executable_schema(type_defs, resolvers)
22+
class GraphQLHTTPConsumer(AsyncHttpConsumer):
23+
def __init__(self, schema: GraphQLSchema, *args, **kwargs):
24+
self.schema = schema
3125
return super().__init__(*args, **kwargs)
3226

27+
@classmethod
28+
def for_schema(cls, schema: GraphQLSchema):
29+
return partial(cls, schema)
30+
3331
async def handle(self, body):
3432
try:
3533
await self.handle_request(body)
@@ -55,7 +53,6 @@ async def handle_http_error(self, error: HttpError) -> None:
5553
)
5654

5755
async def handle_request(self, body: bytes) -> None:
58-
print("REQUEST", self.scope)
5956
if self.scope["method"] == "GET":
6057
await self.handle_get()
6158
if self.scope["method"] == "POST":
@@ -144,3 +141,117 @@ async def return_response_from_result(self, result: ExecutionResult) -> None:
144141
json.dumps(response).encode("utf-8"),
145142
headers=[(b"Content-Type", CONTENT_TYPE_JSON.encode("utf-8"))],
146143
)
144+
145+
146+
class GraphQLWebsocketConsumer(AsyncJsonWebsocketConsumer):
147+
def __init__(self, schema: GraphQLSchema, *args, **kwargs):
148+
self.schema = schema
149+
self.subscriptions: Dict[str, AsyncGenerator] = {}
150+
return super().__init__(*args, **kwargs)
151+
152+
@classmethod
153+
def for_schema(cls, schema: GraphQLSchema):
154+
return partial(cls, schema)
155+
156+
async def connect(self):
157+
await self.accept("graphql-ws")
158+
159+
async def receive_json(self, message: dict):
160+
operation_id = cast(str, message.get("id"))
161+
message_type = cast(str, message.get("type"))
162+
payload = cast(dict, message.get("payload"))
163+
164+
if message_type == "connection_init":
165+
return await self.subscription_init(operation_id, payload)
166+
if message_type == "connection_terminate":
167+
return await self.subscription_terminate(operation_id)
168+
if message_type == "start":
169+
await self.validate_payload(operation_id, payload)
170+
return await self.subscription_start(operation_id, payload)
171+
if message_type == "stop":
172+
return await self.subscription_stop(operation_id)
173+
return await self.send_error(operation_id, "Unknown message type")
174+
175+
async def validate_payload(self, operation_id: str, payload: dict) -> None:
176+
if not isinstance(payload, dict):
177+
return await self.send_error(operation_id, "Payload must be an object")
178+
query = payload.get("query")
179+
if not query or not isinstance(query, str):
180+
return await self.send_error(operation_id, "The query must be a string.")
181+
variables = payload.get("variables")
182+
if variables is not None and not isinstance(variables, dict):
183+
return await self.send_error(
184+
operation_id, "Query variables must be a null or an object."
185+
)
186+
operation_name = payload.get("operationName")
187+
if operation_name is not None and not isinstance(operation_name, str):
188+
return await self.send_error(
189+
operation_id, '"%s" is not a valid operation name.' % operation_name
190+
)
191+
192+
async def send_message(
193+
self, operation_id: str, message_type: str, payload: dict = None
194+
) -> None:
195+
message: Dict[str, Any] = {"type": message_type}
196+
if operation_id is not None:
197+
message["id"] = operation_id
198+
if payload is not None:
199+
message["payload"] = payload
200+
return await self.send_json(message)
201+
202+
async def send_result(self, operation_id: str, result: ExecutionResult) -> None:
203+
payload = {}
204+
if result.data:
205+
payload["data"] = result.data
206+
if result.errors:
207+
payload["errors"] = [format_error(e) for e in result.errors]
208+
await self.send_message(operation_id, "data", payload)
209+
210+
async def send_error(self, operation_id: str, message: str) -> None:
211+
payload = {"message": message}
212+
await self.send_message(operation_id, "error", payload)
213+
214+
async def subscription_init(self, operation_id: str, payload: dict) -> None:
215+
await self.send_message(operation_id, "ack")
216+
217+
async def subscription_start(self, operation_id: str, payload: dict) -> None:
218+
results = await subscribe(
219+
self.schema,
220+
parse(payload["query"]),
221+
root_value=self.get_query_root(payload),
222+
context_value=self.get_query_context(payload),
223+
variable_values=payload.get("variables"),
224+
operation_name=payload.get("operationName"),
225+
)
226+
if isinstance(results, ExecutionResult):
227+
await self.send_result(operation_id, results)
228+
await self.send_message(operation_id, "complete")
229+
else:
230+
asyncio.ensure_future(self.observe(operation_id, results))
231+
232+
async def subscription_stop(self, operation_id: str) -> None:
233+
if operation_id in self.subscriptions:
234+
await self.subscriptions[operation_id].aclose()
235+
del self.subscriptions[operation_id]
236+
237+
def get_query_root(
238+
self, request_data: dict # pylint: disable=unused-argument
239+
) -> Any:
240+
"""Override this method in inheriting class to create query root."""
241+
return None
242+
243+
def get_query_context(
244+
self, request_data: dict # pylint: disable=unused-argument
245+
) -> Any:
246+
"""Override this method in inheriting class to create query context."""
247+
return {"scope": self.scope}
248+
249+
async def observe(self, operation_id: str, results: AsyncGenerator) -> None:
250+
self.subscriptions[operation_id] = results
251+
async for result in results:
252+
await self.send_result(operation_id, result)
253+
await self.send_message(operation_id, "complete", None)
254+
255+
async def disconnect(self, code: Any) -> None:
256+
for operation_id in self.subscriptions:
257+
self.subscription_stop(operation_id)

mysite/routing.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,33 @@
11
import asyncio
2-
from functools import partial
32

4-
from ariadne import ResolverMap
3+
from ariadne import gql, ResolverMap, Subscription
4+
from ariadne.executable_schema import make_executable_schema
55
from channels.routing import ProtocolTypeRouter, URLRouter
66
from django.conf.urls import url
7+
from graphql.pyutils import EventEmitter, EventEmitterAsyncIterator
8+
from graphql.subscription import subscribe
79

8-
from .graphql import GraphQLConsumer
10+
from .graphql import GraphQLHTTPConsumer, GraphQLWebsocketConsumer
911

12+
SCHEMA = gql(
13+
"""
14+
type Query {
15+
hello: String!
16+
}
17+
18+
type Mutation {
19+
sendMessage(message: String!): Boolean!
20+
}
21+
22+
type Subscription {
23+
messages: String!
24+
}
25+
"""
26+
)
27+
mutation = ResolverMap("Mutation")
28+
pubsub = EventEmitter()
1029
query = ResolverMap("Query")
30+
messages = Subscription("messages")
1131

1232

1333
@query.field("hello")
@@ -16,16 +36,30 @@ async def say_hello(root, info):
1636
return "Hello!"
1737

1838

39+
@mutation.field("sendMessage")
40+
async def send_message(root, info, message):
41+
pubsub.emit("message", message)
42+
return True
43+
44+
45+
@messages.subscriber
46+
def subscribe_messages(root, info):
47+
return EventEmitterAsyncIterator(pubsub, "message")
48+
49+
50+
@messages.resolver
51+
def push_message(message, info):
52+
return message
53+
54+
55+
schema = make_executable_schema(SCHEMA, [messages, mutation, query])
56+
1957
application = ProtocolTypeRouter(
2058
{
21-
"http": URLRouter(
22-
[
23-
url(
24-
r"^graphql/$",
25-
partial(GraphQLConsumer, "type Query { hello: String! }", query),
26-
)
27-
]
28-
)
59+
"http": URLRouter([url(r"^graphql/$", GraphQLHTTPConsumer.for_schema(schema))]),
60+
"websocket": URLRouter(
61+
[url(r"^graphql/$", GraphQLWebsocketConsumer.for_schema(schema))]
62+
),
2963
}
3064
)
3165

0 commit comments

Comments
 (0)