Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial Quart Subscription Support #3818

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 101 additions & 7 deletions strawberry/quart/views.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import asyncio
import warnings
from collections.abc import AsyncGenerator, Mapping
from collections.abc import AsyncGenerator, Mapping, Sequence
from datetime import timedelta
from json.decoder import JSONDecodeError
from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast
from typing_extensions import TypeGuard

from quart import Request, Response, request
from quart import Quart, Request, Response, request, websocket
from quart.ctx import has_websocket_context
from quart.views import View
from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
from strawberry.http.exceptions import HTTPException
from strawberry.http.async_base_view import (
AsyncBaseHTTPView,
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import (
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
WebSocketDisconnected,
)
from strawberry.http.ides import GraphQL_IDE
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL

if TYPE_CHECKING:
from quart.typing import ResponseReturnValue
Expand Down Expand Up @@ -46,6 +60,34 @@ async def get_form_data(self) -> FormData:
return FormData(files=files, form=form)


class QuartWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(self, view: AsyncBaseHTTPView, request, ws) -> None:
super().__init__(view)
self.ws = websocket

async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[object, None]:
while True:
message = await self.ws.receive()
if type(message) is bytes:
raise NonTextMessageReceived
try:
yield self.view.decode_json(message)
except JSONDecodeError as e:
if not ignore_parsing_errors:
raise NonJsonMessageReceived from e

async def send_json(self, message: Mapping[str, object]) -> None:
try:
await self.ws.send(self.view.encode_json(message))
except asyncio.CancelledError as exc:
raise WebSocketDisconnected from exc

async def close(self, code: int, reason: str) -> None:
await self.ws.close(code, reason=reason)


class GraphQLView(
AsyncBaseHTTPView[
Request, Response, Response, Request, Response, Context, RootValue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, we would adjust these types to include quart.Websocket and make create_websocket_response return websocket etc. so that we end up with sound types.

That being said, I recognize quart loves their global context vars, so we could go with None here too. In that case get_context might need an update, since it's currently just using the return value from create_websocket_response which is None.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I wasn't quite sure how to approach this exactly because of the context vars, But once the tests are all passing, we can take a look at cleaning this up.

Expand All @@ -55,17 +97,31 @@ class GraphQLView(
methods: ClassVar[list[str]] = ["GET", "POST"]
allow_queries_via_get: bool = True
request_adapter_class = QuartHTTPRequestAdapter
websocket_adapter_class = QuartWebSocketAdapter

def __init__(
self,
schema: "BaseSchema",
graphiql: Optional[bool] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
keep_alive: bool = True,
keep_alive_interval: float = 1,
debug: bool = False,
subscription_protocols: Sequence[str] = [
GRAPHQL_TRANSPORT_WS_PROTOCOL,
GRAPHQL_WS_PROTOCOL,
],
connection_init_wait_timeout: timedelta = timedelta(minutes=1),
multipart_uploads_enabled: bool = False,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
self.keep_alive = keep_alive
self.keep_alive_interval = keep_alive_interval
self.debug = debug
self.subscription_protocols = subscription_protocols
self.connection_init_wait_timeout = connection_init_wait_timeout
self.multipart_uploads_enabled = multipart_uploads_enabled

if graphiql is not None:
Expand Down Expand Up @@ -123,15 +179,53 @@ async def create_streaming_response(
)

def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
return False
if has_websocket_context():
return True

# Check if the request is a WebSocket upgrade request
connection = request.headers.get("Connection", "").lower()
upgrade = request.headers.get("Upgrade", "").lower()

return "upgrade" in connection and "websocket" in upgrade

async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
raise NotImplementedError
# Get the requested protocols
protocols_header = websocket.headers.get("Sec-WebSocket-Protocol", "")
if not protocols_header:
return None

# Find the first matching protocol
requested_protocols = [p.strip() for p in protocols_header.split(",")]
for protocol in requested_protocols:
if protocol in self.subscription_protocols:
return protocol

return None

async def create_websocket_response(
self, request: Request, subprotocol: Optional[str]
) -> Response:
raise NotImplementedError
await websocket.accept(subprotocol=subprotocol)
# Return the current websocket context as the "response"
return None

@classmethod
def register_route(cls, app: Quart, rule_name: str, path: str, **kwargs):
"""Helper method to register both HTTP and WebSocket handlers for a given path.

Args:
app: The Quart application
rule_name: The name of the rule
path: The path to register the handlers for
**kwargs: Parameters to pass to the GraphQLView constructor
"""
# Register both HTTP and WebSocket handler at the same path
view_func = cls.as_view(rule_name, **kwargs)
app.add_url_rule(path, view_func=view_func, methods=["GET", "POST"])

# Register the WebSocket handler using the same view function
# Quart will handle routing based on the WebSocket upgrade header
app.add_url_rule(path, view_func=view_func, methods=["GET"], websocket=True)
Comment on lines +212 to +228
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd prefer if we left it up to the user to register the two URL rules. The examples in our docs for adding HTTP/WS/HTTP+WS URL rules would then end up looking more consistent.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more of a convenience method, but if you think that just documenting it is good enough, then I'm fine to remove it. Code that doesn't exist doesn't have any bugs :D



__all__ = ["GraphQLView"]
165 changes: 161 additions & 4 deletions tests/http/clients/quart.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,51 @@
import asyncio
import contextlib
import json
import urllib.parse
from collections.abc import AsyncGenerator, Mapping
from io import BytesIO
from typing import Any, Optional
from typing import Any, Optional, Union
from typing_extensions import Literal

from asgiref.typing import ASGISendEvent
from hypercorn.typing import WebsocketScope

from quart import Quart
from quart import Request as QuartRequest
from quart import Response as QuartResponse
from quart.datastructures import FileStorage
from quart.testing.connections import TestWebsocketConnection
from quart.typing import TestWebsocketConnectionProtocol
from quart.utils import decode_headers
from strawberry.exceptions import ConnectionRejectionError
from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE
from strawberry.quart.views import GraphQLView as BaseGraphQLView
from strawberry.types import ExecutionResult
from strawberry.types.unset import UNSET, UnsetType
from tests.http.context import get_context
from tests.views.schema import Query, schema

from .base import JSON, HttpClient, Response, ResultOverrideFunction
from .base import (
JSON,
DebuggableGraphQLTransportWSHandler,
DebuggableGraphQLWSHandler,
HttpClient,
Message,
Response,
ResultOverrideFunction,
WebSocketClient,
)


class GraphQLView(BaseGraphQLView[dict[str, object], object]):
methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"]

graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler
graphql_ws_handler_class = DebuggableGraphQLWSHandler
result_override: ResultOverrideFunction = None

def __init__(self, *args: Any, **kwargs: Any):
self.result_override = kwargs.pop("result_override")
self.result_override = kwargs.pop("result_override", None)
super().__init__(*args, **kwargs)

async def get_root_value(self, request: QuartRequest) -> Query:
Expand All @@ -46,6 +67,28 @@ async def process_result(

return await super().process_result(request, result)

async def on_ws_connect(
self, context: dict[str, object]
) -> Union[UnsetType, None, dict[str, object]]:
connection_params = context["connection_params"]

if isinstance(connection_params, dict):
if connection_params.get("test-reject"):
if "err-payload" in connection_params:
raise ConnectionRejectionError(connection_params["err-payload"])
raise ConnectionRejectionError

if connection_params.get("test-accept"):
if "ack-payload" in connection_params:
return connection_params["ack-payload"]
return UNSET

if connection_params.get("test-modify"):
connection_params["modified"] = True
return UNSET

return await super().on_ws_connect(context)


class QuartHttpClient(HttpClient):
def __init__(
Expand Down Expand Up @@ -73,6 +116,23 @@ def __init__(
"/graphql",
view_func=view,
)
self.app.add_url_rule(
"/graphql", view_func=view, methods=["GET"], websocket=True
)

def create_app(self, **kwargs: Any) -> None:
self.app = Quart(__name__)
self.app.debug = True

view = GraphQLView.as_view("graphql_view", schema=schema, **kwargs)

self.app.add_url_rule(
"/graphql",
view_func=view,
)
self.app.add_url_rule(
"/graphql", view_func=view, methods=["GET"], websocket=True
)

async def _graphql_request(
self,
Expand Down Expand Up @@ -140,3 +200,100 @@ async def post(
return await self.request(
url, "post", **{k: v for k, v in kwargs.items() if v is not None}
)

@contextlib.asynccontextmanager
async def ws_connect(
self,
url: str,
*,
protocols: list[str],
) -> AsyncGenerator[WebSocketClient, None]:
headers = {
"sec-websocket-protocol": ", ".join(protocols),
}
async with self.app.test_app() as test_app:
client = test_app.test_client()
client.websocket_connection_class = QuartTestWebsocketConnection
async with client.websocket(
url, headers=headers, subprotocols=protocols
) as ws:
yield QuartWebSocketClient(ws)


class QuartTestWebsocketConnection(TestWebsocketConnection):
def __init__(self, app: Quart, scope: WebsocketScope) -> None:
scope["asgi"] = {"spec_version": "2.3"}
super().__init__(app, scope)

async def _asgi_send(self, message: ASGISendEvent) -> None:
if message["type"] == "websocket.accept":
self.accepted = True
elif message["type"] == "websocket.send":
await self._receive_queue.put(message.get("bytes") or message.get("text"))
elif message["type"] == "websocket.http.response.start":
self.headers = decode_headers(message["headers"])
self.status_code = message["status"]
elif message["type"] == "websocket.http.response.body":
self.response_data.extend(message["body"])
elif message["type"] == "websocket.close":
await self._receive_queue.put(json.dumps(message))


class QuartWebSocketClient(WebSocketClient):
def __init__(self, ws: TestWebsocketConnectionProtocol):
self.ws = ws
self._closed: bool = False
self._close_code: Optional[int] = None
self._close_reason: Optional[str] = None

async def send_text(self, payload: str) -> None:
await self.ws.send(payload)

async def send_json(self, payload: Mapping[str, object]) -> None:
await self.ws.send_json(payload)

async def send_bytes(self, payload: bytes) -> None:
await self.ws.send(payload)

async def receive(self, timeout: Optional[float] = None) -> Message:
if self._closed:
# if close was received via exception, fake it so that recv works
return Message(
type="websocket.close", data=self._close_code, extra=self._close_reason
)
m = await asyncio.wait_for(self.ws.receive_json(), timeout=timeout)
if m["type"] == "websocket.close":
self._closed = True
self._close_code = m["code"]
self._close_reason = m.get("reason", None)
return Message(type=m["type"], data=m["code"], extra=m.get("reason", None))
if m["type"] == "websocket.send":
return Message(type=m["type"], data=m["text"])
if m["type"] == "connection_ack":
return Message(type=m["type"], data="")
return Message(type=m["type"], data=m["data"], extra=m["extra"])

async def receive_json(self, timeout: Optional[float] = None) -> Any:
m = await asyncio.wait_for(self.ws.receive_json(), timeout=timeout)
return m

async def close(self) -> None:
await self.ws.close(1000)
self._closed = True

@property
def accepted_subprotocol(self) -> Optional[str]:
return ""

@property
def closed(self) -> bool:
return self._closed

@property
def close_code(self) -> int:
assert self._close_code is not None
return self._close_code

@property
def close_reason(self) -> Optional[str]:
return self._close_reason
1 change: 1 addition & 0 deletions tests/websockets/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]:
("ChannelsHttpClient", "channels", [pytest.mark.channels]),
("FastAPIHttpClient", "fastapi", [pytest.mark.fastapi]),
("LitestarHttpClient", "litestar", [pytest.mark.litestar]),
("QuartHttpClient", "quart", [pytest.mark.quart]),
]:
try:
client_class = getattr(
Expand Down
1 change: 0 additions & 1 deletion tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def assert_next(

async def test_unknown_message_type(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_json({"type": "NOT_A_MESSAGE_TYPE"})

await ws.receive(timeout=2)
Expand Down
Loading