1
+ import asyncio
1
2
import json
2
- from typing import Any , Callable , List , Optional , Union
3
+ from functools import partial
4
+ from typing import Any , AsyncGenerator , Callable , Dict , List , Optional , Union , cast
3
5
4
6
from ariadne .constants import (
5
7
CONTENT_TYPE_JSON ,
6
8
CONTENT_TYPE_TEXT_HTML ,
7
9
CONTENT_TYPE_TEXT_PLAIN ,
8
10
DATA_TYPE_JSON ,
9
- HTTP_STATUS_200_OK ,
10
- HTTP_STATUS_400_BAD_REQUEST ,
11
11
PLAYGROUND_HTML ,
12
12
)
13
13
from ariadne .exceptions import HttpBadRequestError , HttpError , HttpMethodNotAllowedError
14
- from ariadne .executable_schema import make_executable_schema
15
14
from ariadne .types import Bindable
16
15
from channels .generic .http import AsyncHttpConsumer
17
- from graphql import GraphQLError , format_error , graphql
16
+ from channels .generic .websocket import AsyncJsonWebsocketConsumer
17
+ from graphql import GraphQLError , GraphQLSchema , format_error , graphql , parse , subscribe
18
18
from graphql .execution import ExecutionResult
19
19
import traceback
20
20
21
21
22
- class GraphQLConsumer (AsyncHttpConsumer ):
23
- def __init__ (
24
- self ,
25
- type_defs : Union [str , List [str ]],
26
- resolvers : Union [Bindable , List [Bindable ], None ] = None ,
27
- * args ,
28
- ** kwargs
29
- ):
30
- self .schema = make_executable_schema (type_defs , resolvers )
22
+ class GraphQLHTTPConsumer (AsyncHttpConsumer ):
23
+ def __init__ (self , schema : GraphQLSchema , * args , ** kwargs ):
24
+ self .schema = schema
31
25
return super ().__init__ (* args , ** kwargs )
32
26
27
+ @classmethod
28
+ def for_schema (cls , schema : GraphQLSchema ):
29
+ return partial (cls , schema )
30
+
33
31
async def handle (self , body ):
34
32
try :
35
33
await self .handle_request (body )
@@ -55,7 +53,6 @@ async def handle_http_error(self, error: HttpError) -> None:
55
53
)
56
54
57
55
async def handle_request (self , body : bytes ) -> None :
58
- print ("REQUEST" , self .scope )
59
56
if self .scope ["method" ] == "GET" :
60
57
await self .handle_get ()
61
58
if self .scope ["method" ] == "POST" :
@@ -144,3 +141,117 @@ async def return_response_from_result(self, result: ExecutionResult) -> None:
144
141
json .dumps (response ).encode ("utf-8" ),
145
142
headers = [(b"Content-Type" , CONTENT_TYPE_JSON .encode ("utf-8" ))],
146
143
)
144
+
145
+
146
+ class GraphQLWebsocketConsumer (AsyncJsonWebsocketConsumer ):
147
+ def __init__ (self , schema : GraphQLSchema , * args , ** kwargs ):
148
+ self .schema = schema
149
+ self .subscriptions : Dict [str , AsyncGenerator ] = {}
150
+ return super ().__init__ (* args , ** kwargs )
151
+
152
+ @classmethod
153
+ def for_schema (cls , schema : GraphQLSchema ):
154
+ return partial (cls , schema )
155
+
156
+ async def connect (self ):
157
+ await self .accept ("graphql-ws" )
158
+
159
+ async def receive_json (self , message : dict ):
160
+ operation_id = cast (str , message .get ("id" ))
161
+ message_type = cast (str , message .get ("type" ))
162
+ payload = cast (dict , message .get ("payload" ))
163
+
164
+ if message_type == "connection_init" :
165
+ return await self .subscription_init (operation_id , payload )
166
+ if message_type == "connection_terminate" :
167
+ return await self .subscription_terminate (operation_id )
168
+ if message_type == "start" :
169
+ await self .validate_payload (operation_id , payload )
170
+ return await self .subscription_start (operation_id , payload )
171
+ if message_type == "stop" :
172
+ return await self .subscription_stop (operation_id )
173
+ return await self .send_error (operation_id , "Unknown message type" )
174
+
175
+ async def validate_payload (self , operation_id : str , payload : dict ) -> None :
176
+ if not isinstance (payload , dict ):
177
+ return await self .send_error (operation_id , "Payload must be an object" )
178
+ query = payload .get ("query" )
179
+ if not query or not isinstance (query , str ):
180
+ return await self .send_error (operation_id , "The query must be a string." )
181
+ variables = payload .get ("variables" )
182
+ if variables is not None and not isinstance (variables , dict ):
183
+ return await self .send_error (
184
+ operation_id , "Query variables must be a null or an object."
185
+ )
186
+ operation_name = payload .get ("operationName" )
187
+ if operation_name is not None and not isinstance (operation_name , str ):
188
+ return await self .send_error (
189
+ operation_id , '"%s" is not a valid operation name.' % operation_name
190
+ )
191
+
192
+ async def send_message (
193
+ self , operation_id : str , message_type : str , payload : dict = None
194
+ ) -> None :
195
+ message : Dict [str , Any ] = {"type" : message_type }
196
+ if operation_id is not None :
197
+ message ["id" ] = operation_id
198
+ if payload is not None :
199
+ message ["payload" ] = payload
200
+ return await self .send_json (message )
201
+
202
+ async def send_result (self , operation_id : str , result : ExecutionResult ) -> None :
203
+ payload = {}
204
+ if result .data :
205
+ payload ["data" ] = result .data
206
+ if result .errors :
207
+ payload ["errors" ] = [format_error (e ) for e in result .errors ]
208
+ await self .send_message (operation_id , "data" , payload )
209
+
210
+ async def send_error (self , operation_id : str , message : str ) -> None :
211
+ payload = {"message" : message }
212
+ await self .send_message (operation_id , "error" , payload )
213
+
214
+ async def subscription_init (self , operation_id : str , payload : dict ) -> None :
215
+ await self .send_message (operation_id , "ack" )
216
+
217
+ async def subscription_start (self , operation_id : str , payload : dict ) -> None :
218
+ results = await subscribe (
219
+ self .schema ,
220
+ parse (payload ["query" ]),
221
+ root_value = self .get_query_root (payload ),
222
+ context_value = self .get_query_context (payload ),
223
+ variable_values = payload .get ("variables" ),
224
+ operation_name = payload .get ("operationName" ),
225
+ )
226
+ if isinstance (results , ExecutionResult ):
227
+ await self .send_result (operation_id , results )
228
+ await self .send_message (operation_id , "complete" )
229
+ else :
230
+ asyncio .ensure_future (self .observe (operation_id , results ))
231
+
232
+ async def subscription_stop (self , operation_id : str ) -> None :
233
+ if operation_id in self .subscriptions :
234
+ await self .subscriptions [operation_id ].aclose ()
235
+ del self .subscriptions [operation_id ]
236
+
237
+ def get_query_root (
238
+ self , request_data : dict # pylint: disable=unused-argument
239
+ ) -> Any :
240
+ """Override this method in inheriting class to create query root."""
241
+ return None
242
+
243
+ def get_query_context (
244
+ self , request_data : dict # pylint: disable=unused-argument
245
+ ) -> Any :
246
+ """Override this method in inheriting class to create query context."""
247
+ return {"scope" : self .scope }
248
+
249
+ async def observe (self , operation_id : str , results : AsyncGenerator ) -> None :
250
+ self .subscriptions [operation_id ] = results
251
+ async for result in results :
252
+ await self .send_result (operation_id , result )
253
+ await self .send_message (operation_id , "complete" , None )
254
+
255
+ async def disconnect (self , code : Any ) -> None :
256
+ for operation_id in self .subscriptions :
257
+ self .subscription_stop (operation_id )
0 commit comments