Skip to content

[WIP] Subscriptions #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Oct 25, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ language: python
sudo: false
python:
- 2.7
- 3.4
- 3.5
- 3.6
- "pypy-5.3.1"
# - "pypy-5.3.1"
before_install:
- |
if [ "$TRAVIS_PYTHON_VERSION" = "pypy" ]; then
Expand All @@ -22,7 +19,9 @@ before_install:
fi
install:
- pip install -e .[test]
- pip install flake8
script:
- flake8
- py.test --cov=graphql graphql tests
after_success:
- coveralls
Expand All @@ -33,10 +32,13 @@ matrix:
- pip install pytest-asyncio
script:
- py.test --cov=graphql graphql tests tests_py35
- python: '2.7'
install: pip install flake8
- python: '3.6'
after_install:
- pip install pytest-asyncio
script:
- flake8
- py.test --cov=graphql graphql tests tests_py35
- python: '2.7'

deploy:
provider: pypi
user: syrusakbary
Expand Down
2 changes: 2 additions & 0 deletions graphql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
# Execute GraphQL queries.
from .execution import ( # no import order
execute,
subscribe,
ResolveInfo,
MiddlewareManager,
middlewares
Expand Down Expand Up @@ -254,6 +255,7 @@
'print_ast',
'visit',
'execute',
'subscribe',
'ResolveInfo',
'MiddlewareManager',
'middlewares',
Expand Down
3 changes: 2 additions & 1 deletion graphql/execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
2) fragment "spreads" e.g. "...c"
3) inline fragment "spreads" e.g. "...on Type { a }"
"""
from .executor import execute
from .executor import execute, subscribe
from .base import ExecutionResult, ResolveInfo
from .middleware import middlewares, MiddlewareManager


__all__ = [
'execute',
'subscribe',
'ExecutionResult',
'ResolveInfo',
'MiddlewareManager',
Expand Down
40 changes: 31 additions & 9 deletions graphql/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class ExecutionContext(object):
and the fragments defined in the query document"""

__slots__ = 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'errors', 'context_value', \
'argument_values_cache', 'executor', 'middleware', '_subfields_cache'
'argument_values_cache', 'executor', 'middleware', 'allow_subscriptions', '_subfields_cache'

def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware):
def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware, allow_subscriptions):
"""Constructs a ExecutionContext object from the arguments passed
to execute, which we will pass throughout the other execution
methods."""
Expand All @@ -32,7 +32,8 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val
for definition in document_ast.definitions:
if isinstance(definition, ast.OperationDefinition):
if not operation_name and operation:
raise GraphQLError('Must provide operation name if query contains multiple operations.')
raise GraphQLError(
'Must provide operation name if query contains multiple operations.')

if not operation_name or definition.name and definition.name.value == operation_name:
operation = definition
Expand All @@ -42,18 +43,21 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val

else:
raise GraphQLError(
u'GraphQL cannot execute a request containing a {}.'.format(definition.__class__.__name__),
u'GraphQL cannot execute a request containing a {}.'.format(
definition.__class__.__name__),
definition
)

if not operation:
if operation_name:
raise GraphQLError(u'Unknown operation named "{}".'.format(operation_name))
raise GraphQLError(
u'Unknown operation named "{}".'.format(operation_name))

else:
raise GraphQLError('Must provide an operation.')

variable_values = get_variable_values(schema, operation.variable_definitions or [], variable_values)
variable_values = get_variable_values(
schema, operation.variable_definitions or [], variable_values)

self.schema = schema
self.fragments = fragments
Expand All @@ -65,6 +69,7 @@ def __init__(self, schema, document_ast, root_value, context_value, variable_val
self.argument_values_cache = {}
self.executor = executor
self.middleware = middleware
self.allow_subscriptions = allow_subscriptions
self._subfields_cache = {}

def get_field_resolver(self, field_resolver):
Expand All @@ -82,7 +87,8 @@ def get_argument_values(self, field_def, field_ast):
return result

def report_error(self, error, traceback=None):
sys.excepthook(type(error), error, getattr(error, 'stack', None) or traceback)
sys.excepthook(type(error), error, getattr(
error, 'stack', None) or traceback)
self.errors.append(error)

def get_sub_fields(self, return_type, field_asts):
Expand All @@ -101,6 +107,20 @@ def get_sub_fields(self, return_type, field_asts):
return self._subfields_cache[k]


class SubscriberExecutionContext(object):
__slots__ = 'exe_context', 'errors'

def __init__(self, exe_context):
self.exe_context = exe_context
self.errors = []

def reset(self):
self.errors = []

def __getattr__(self, name):
return getattr(self.exe_context, name)


class ExecutionResult(object):
"""The result of execution. `data` is the result of executing the
query, `errors` is null if no errors occurred, and is a
Expand Down Expand Up @@ -186,7 +206,8 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names
ctx, selection, runtime_type):
continue

collect_fields(ctx, runtime_type, selection.selection_set, fields, prev_fragment_names)
collect_fields(ctx, runtime_type,
selection.selection_set, fields, prev_fragment_names)

elif isinstance(selection, ast.FragmentSpread):
frag_name = selection.name.value
Expand All @@ -202,7 +223,8 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names
does_fragment_condition_match(ctx, fragment, runtime_type):
continue

collect_fields(ctx, runtime_type, fragment.selection_set, fields, prev_fragment_names)
collect_fields(ctx, runtime_type,
fragment.selection_set, fields, prev_fragment_names)

return fields

Expand Down
121 changes: 118 additions & 3 deletions graphql/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import logging
import sys
from rx import Observable

from six import string_types
from promise import Promise, promise_for_dict, is_thenable
Expand All @@ -15,16 +16,21 @@
GraphQLSchema, GraphQLUnionType)
from .base import (ExecutionContext, ExecutionResult, ResolveInfo,
collect_fields, default_resolve_fn, get_field_def,
get_operation_root_type)
get_operation_root_type, SubscriberExecutionContext)
from .executors.sync import SyncExecutor
from .middleware import MiddlewareManager

logger = logging.getLogger(__name__)


def subscribe(*args, **kwargs):
allow_subscriptions = kwargs.pop('allow_subscriptions', True)
return execute(*args, allow_subscriptions=allow_subscriptions, **kwargs)


def execute(schema, document_ast, root_value=None, context_value=None,
variable_values=None, operation_name=None, executor=None,
return_promise=False, middleware=None):
return_promise=False, middleware=None, allow_subscriptions=False):
assert schema, 'Must provide schema'
assert isinstance(schema, GraphQLSchema), (
'Schema must be an instance of GraphQLSchema. Also ensure that there are ' +
Expand All @@ -50,7 +56,8 @@ def execute(schema, document_ast, root_value=None, context_value=None,
variable_values,
operation_name,
executor,
middleware
middleware,
allow_subscriptions
)

def executor(v):
Expand All @@ -61,6 +68,9 @@ def on_rejected(error):
return None

def on_resolve(data):
if isinstance(data, Observable):
return data

if not context.errors:
return ExecutionResult(data=data)
return ExecutionResult(data=data, errors=context.errors)
Expand Down Expand Up @@ -88,6 +98,15 @@ def execute_operation(exe_context, operation, root_value):
if operation.operation == 'mutation':
return execute_fields_serially(exe_context, type, root_value, fields)

if operation.operation == 'subscription':
if not exe_context.allow_subscriptions:
raise Exception(
"Subscriptions are not allowed. "
"You will need to either use the subscribe function "
"or pass allow_subscriptions=True"
)
return subscribe_fields(exe_context, type, root_value, fields)

return execute_fields(exe_context, type, root_value, fields)


Expand Down Expand Up @@ -140,6 +159,44 @@ def execute_fields(exe_context, parent_type, source_value, fields):
return promise_for_dict(final_results)


def subscribe_fields(exe_context, parent_type, source_value, fields):
exe_context = SubscriberExecutionContext(exe_context)

def on_error(error):
exe_context.report_error(error)

def map_result(data):
if exe_context.errors:
result = ExecutionResult(data=data, errors=exe_context.errors)
else:
result = ExecutionResult(data=data)
exe_context.reset()
return result

observables = []

# assert len(fields) == 1, "Can only subscribe one element at a time."

for response_name, field_asts in fields.items():

result = subscribe_field(exe_context, parent_type,
source_value, field_asts)
if result is Undefined:
continue

def catch_error(error):
exe_context.errors.append(error)
return Observable.just(None)

# Map observable results
observable = result.catch_exception(catch_error).map(
lambda data: map_result({response_name: data}))
return observable
observables.append(observable)

return Observable.merge(observables)


def resolve_field(exe_context, parent_type, source, field_asts):
field_ast = field_asts[0]
field_name = field_ast.name.value
Expand Down Expand Up @@ -191,6 +248,64 @@ def resolve_field(exe_context, parent_type, source, field_asts):
)


def subscribe_field(exe_context, parent_type, source, field_asts):
field_ast = field_asts[0]
field_name = field_ast.name.value

field_def = get_field_def(exe_context.schema, parent_type, field_name)
if not field_def:
return Undefined

return_type = field_def.type
resolve_fn = field_def.resolver or default_resolve_fn

# We wrap the resolve_fn from the middleware
resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn)

# Build a dict of arguments from the field.arguments AST, using the variables scope to
# fulfill any variable references.
args = exe_context.get_argument_values(field_def, field_ast)

# The resolve function's optional third argument is a context value that
# is provided to every resolve function within an execution. It is commonly
# used to represent an authenticated user, or request-specific caches.
context = exe_context.context_value

# The resolve function's optional third argument is a collection of
# information about the current execution state.
info = ResolveInfo(
field_name,
field_asts,
return_type,
parent_type,
schema=exe_context.schema,
fragments=exe_context.fragments,
root_value=exe_context.root_value,
operation=exe_context.operation,
variable_values=exe_context.variable_values,
context=context
)

executor = exe_context.executor
result = resolve_or_error(resolve_fn_middleware,
source, info, args, executor)

if isinstance(result, Exception):
raise result

if not isinstance(result, Observable):
raise GraphQLError(
'Subscription must return Async Iterable or Observable. Received: {}'.format(repr(result)))

return result.map(functools.partial(
complete_value_catching_error,
exe_context,
return_type,
field_asts,
info,
))


def resolve_or_error(resolve_fn, source, info, args, executor):
try:
return executor.execute(resolve_fn, source, info, **args)
Expand Down
12 changes: 11 additions & 1 deletion graphql/execution/executors/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ def ensure_future(coro_or_future, loop=None):
del task._source_traceback[-1]
return task
else:
raise TypeError('A Future, a coroutine or an awaitable is required')
raise TypeError(
'A Future, a coroutine or an awaitable is required')

try:
from .asyncio_utils import asyncgen_to_observable, isasyncgen
except Exception:
def isasyncgen(obj): False

def asyncgen_to_observable(asyncgen): pass


class AsyncioExecutor(object):
Expand All @@ -50,4 +58,6 @@ def execute(self, fn, *args, **kwargs):
future = ensure_future(result, loop=self.loop)
self.futures.append(future)
return Promise.resolve(future)
elif isasyncgen(result):
return asyncgen_to_observable(result)
return result
Loading