Skip to content
This repository was archived by the owner on May 21, 2019. It is now read-only.

Commit 121de84

Browse files
committed
Run with unmodified Ariadne
1 parent 7c38e82 commit 121de84

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

mysite/routing.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import asyncio
22

3-
from ariadne import gql, ResolverMap, Subscription
3+
from ariadne import gql, ResolverMap
44
from ariadne.executable_schema import make_executable_schema
55
from channels.routing import ProtocolTypeRouter, URLRouter
66
from django.conf.urls import url
77
from graphql.pyutils import EventEmitter, EventEmitterAsyncIterator
88
from graphql.subscription import subscribe
99

1010
from .graphql import GraphQLHTTPConsumer, GraphQLWebsocketConsumer
11+
from .subscription import SubscriptionAwareResolverMap
1112

1213
SCHEMA = gql(
1314
"""
@@ -27,7 +28,7 @@
2728
mutation = ResolverMap("Mutation")
2829
pubsub = EventEmitter()
2930
query = ResolverMap("Query")
30-
messages = Subscription("messages")
31+
subscription = SubscriptionAwareResolverMap("Subscription")
3132

3233

3334
@query.field("hello")
@@ -42,17 +43,18 @@ async def send_message(root, info, message):
4243
return True
4344

4445

45-
@messages.subscriber
46+
@subscription.subscription("messages")
4647
def subscribe_messages(root, info):
4748
return EventEmitterAsyncIterator(pubsub, "message")
4849

4950

50-
@messages.resolver
51+
@subscription.field("messages")
5152
def push_message(message, info):
5253
return message
5354

5455

55-
schema = make_executable_schema(SCHEMA, [messages, mutation, query])
56+
schema = make_executable_schema(SCHEMA, [mutation, query, subscription])
57+
5658

5759
application = ProtocolTypeRouter(
5860
{

mysite/subscription.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import Any, AsyncIterator, Callable, Dict, overload
2+
3+
from ariadne import ResolverMap
4+
from graphql import GraphQLSchema
5+
6+
Subscriber = Callable[..., AsyncIterator]
7+
8+
9+
class SubscriptionAwareResolverMap(ResolverMap):
10+
_subscribers: Dict[str, Subscriber]
11+
12+
def __init__(self, name: str):
13+
super().__init__(name)
14+
self._subscribers = {}
15+
16+
@overload
17+
def subscription(self, name: str) -> Callable[[Subscriber], Subscriber]:
18+
pass # pragma: no cover
19+
20+
@overload
21+
def subscription( # pylint: disable=function-redefined
22+
self, name: str, *, subscriber: Subscriber
23+
) -> Subscriber: # pylint: disable=function-redefined
24+
pass # pragma: no cover
25+
26+
def subscription(
27+
self, name, *, subscriber=None
28+
): # pylint: disable=function-redefined
29+
if not subscriber:
30+
return self.create_register_subscriber(name)
31+
self._subscribers[name] = subscriber
32+
return subscriber
33+
34+
def create_register_subscriber(
35+
self, name: str
36+
) -> Callable[[Subscriber], Subscriber]:
37+
def register_subscriber(f: Subscriber) -> Subscriber:
38+
self._subscribers[name] = f
39+
return f
40+
41+
return register_subscriber
42+
43+
def bind_to_schema(self, schema: GraphQLSchema) -> None:
44+
super().bind_to_schema(schema)
45+
graphql_type = schema.type_map.get(self.name)
46+
for field, subscriber in self._subscribers.items():
47+
if field not in graphql_type.fields:
48+
raise ValueError(
49+
"Field %s is not defined on type %s" % (field, self.name)
50+
)
51+
graphql_type.fields[field].subscribe = subscriber

0 commit comments

Comments
 (0)