Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: flask asyncio support for dataloaders #66

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: fix up calls and added async unit tests
  • Loading branch information
Cameron Hurst committed Jan 6, 2021
commit 2955648c1d991de92ac9402ee41f5f8c9a86ae70
5 changes: 3 additions & 2 deletions graphql_server/flask/graphqlview.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class GraphQLView(View):
default_query = None
header_editor_enabled = None
should_persist_headers = None
enable_async = True
enable_async = False

methods = ["GET", "POST", "PUT", "DELETE"]

Expand Down Expand Up @@ -74,7 +74,7 @@ def get_middleware(self):
@staticmethod
def get_async_execution_results(execution_results):
async def await_execution_results(execution_results):
return [ex if ex is None or is_awaitable(ex) else await ex for ex in execution_results]
return [ex if ex is None or not is_awaitable(ex) else await ex for ex in execution_results]

return asyncio.run(await_execution_results(execution_results))

Expand All @@ -100,6 +100,7 @@ def dispatch_request(self):
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
run_sync=not self.enable_async,
)

if self.enable_async:
Expand Down
9 changes: 5 additions & 4 deletions tests/flask/app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from flask import Flask

from graphql_server.flask import GraphQLView
from tests.flask.schema import Schema
from tests.flask.schema import AsyncSchema, Schema


def create_app(path="/graphql", **kwargs):
server = Flask(__name__)
server.debug = True
server.add_url_rule(
path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs)
)
if kwargs.get("enable_async", None):
server.add_url_rule(path, view_func=GraphQLView.as_view("graphql", schema=AsyncSchema, **kwargs))
else:
server.add_url_rule(path, view_func=GraphQLView.as_view("graphql", schema=Schema, **kwargs))
return server


Expand Down
25 changes: 22 additions & 3 deletions tests/flask/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,28 @@ def resolve_raises(*_):

MutationRootType = GraphQLObjectType(
name="MutationRoot",
fields={
"writeTest": GraphQLField(type_=QueryRootType, resolve=lambda *_: QueryRootType)
},
fields={"writeTest": GraphQLField(type_=QueryRootType, resolve=lambda *_: QueryRootType)},
)

Schema = GraphQLSchema(QueryRootType, MutationRootType)


async def async_resolver(obj, info):
return "async"


AsyncQueryRootType = GraphQLObjectType(
name="QueryRoot",
fields={
"sync": GraphQLField(GraphQLNonNull(GraphQLString), resolve=lambda obj, info: "sync"),
"nsync": GraphQLField(GraphQLNonNull(GraphQLString), resolve=async_resolver),
},
)
AsyncMutationRootType = GraphQLObjectType(
name="MutationRoot",
fields={
"sync": GraphQLField(type_=GraphQLString, resolve=lambda obj, info: "sync"),
"nsync": GraphQLField(type_=GraphQLString, resolve=async_resolver),
},
)
AsyncSchema = GraphQLSchema(AsyncQueryRootType, AsyncMutationRootType)
79 changes: 41 additions & 38 deletions tests/flask/test_graphqlview.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def test_allows_get_with_operation_name(app, client):
)

assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}}


def test_reports_validation_errors(app, client):
Expand Down Expand Up @@ -272,7 +270,9 @@ def test_supports_post_url_encoded_query_with_string_variables(app, client):
def test_supports_post_json_query_with_get_variable_values(app, client):
response = client.post(
url_string(app, variables=json.dumps({"who": "Dolly"})),
data=json_dump_kwarg(query="query helloWho($who: String){ test(who: $who) }",),
data=json_dump_kwarg(
query="query helloWho($who: String){ test(who: $who) }",
),
content_type="application/json",
)

Expand All @@ -283,7 +283,11 @@ def test_supports_post_json_query_with_get_variable_values(app, client):
def test_post_url_encoded_query_with_get_variable_values(app, client):
response = client.post(
url_string(app, variables=json.dumps({"who": "Dolly"})),
data=urlencode(dict(query="query helloWho($who: String){ test(who: $who) }",)),
data=urlencode(
dict(
query="query helloWho($who: String){ test(who: $who) }",
)
),
content_type="application/x-www-form-urlencoded",
)

Expand Down Expand Up @@ -320,9 +324,7 @@ def test_allows_post_with_operation_name(app, client):
)

assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}}


def test_allows_post_with_get_operation_name(app, client):
Expand All @@ -340,18 +342,14 @@ def test_allows_post_with_get_operation_name(app, client):
)

assert response.status_code == 200
assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
assert response_json(response) == {"data": {"test": "Hello World", "shared": "Hello Everyone"}}


@pytest.mark.parametrize("app", [create_app(pretty=True)])
def test_supports_pretty_printing(app, client):
response = client.get(url_string(app, query="{test}"))

assert response.data.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
assert response.data.decode() == ("{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}")


@pytest.mark.parametrize("app", [create_app(pretty=False)])
Expand All @@ -364,9 +362,7 @@ def test_not_pretty_by_default(app, client):
def test_supports_pretty_printing_by_request(app, client):
response = client.get(url_string(app, query="{test}", pretty="1"))

assert response.data.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
assert response.data.decode() == ("{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}")


def test_handles_field_errors_caught_by_graphql(app, client):
Expand Down Expand Up @@ -403,9 +399,7 @@ def test_handles_errors_caused_by_a_lack_of_query(app, client):

assert response.status_code == 400
assert response_json(response) == {
"errors": [
{"message": "Must provide query string.", "locations": None, "path": None}
]
"errors": [{"message": "Must provide query string.", "locations": None, "path": None}]
}


Expand All @@ -425,15 +419,11 @@ def test_handles_batch_correctly_if_is_disabled(app, client):


def test_handles_incomplete_json_bodies(app, client):
response = client.post(
url_string(app), data='{"query":', content_type="application/json"
)
response = client.post(url_string(app), data='{"query":', content_type="application/json")

assert response.status_code == 400
assert response_json(response) == {
"errors": [
{"message": "POST body sent invalid JSON.", "locations": None, "path": None}
]
"errors": [{"message": "POST body sent invalid JSON.", "locations": None, "path": None}]
}


Expand All @@ -445,9 +435,7 @@ def test_handles_plain_post_text(app, client):
)
assert response.status_code == 400
assert response_json(response) == {
"errors": [
{"message": "Must provide query string.", "locations": None, "path": None}
]
"errors": [{"message": "Must provide query string.", "locations": None, "path": None}]
}


Expand All @@ -461,9 +449,7 @@ def test_handles_poorly_formed_variables(app, client):
)
assert response.status_code == 400
assert response_json(response) == {
"errors": [
{"message": "Variables are invalid JSON.", "locations": None, "path": None}
]
"errors": [{"message": "Variables are invalid JSON.", "locations": None, "path": None}]
}


Expand Down Expand Up @@ -524,9 +510,7 @@ def test_post_multipart_data(app, client):
)

assert response.status_code == 200
assert response_json(response) == {
"data": {u"writeTest": {u"test": u"Hello World"}}
}
assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}


@pytest.mark.parametrize("app", [create_app(batch=True)])
Expand Down Expand Up @@ -575,6 +559,25 @@ def test_batch_allows_post_with_operation_name(app, client):
)

assert response.status_code == 200
assert response_json(response) == [
{"data": {"test": "Hello World", "shared": "Hello Everyone"}}
]
assert response_json(response) == [{"data": {"test": "Hello World", "shared": "Hello Everyone"}}]


@pytest.mark.parametrize(
("query", "result"),
(
("query sync {sync}", {"sync": "sync"}),
("query nsync {nsync}", {"nsync": "async"}),
("mutation sync {sync}", {"sync": "sync"}),
("mutation nsync {nsync}", {"nsync": "async"}),
),
)
@pytest.mark.parametrize("app", [create_app(enable_async=True)])
def test_async_client_handles_sync_calls(app, client, query, result):
response = client.post(
url_string(app),
data=json_dump_kwarg(query=query),
content_type="application/json",
)

assert response.status_code == 200, response.data
assert response_json(response) == {"data": result}