-
-
Notifications
You must be signed in to change notification settings - Fork 561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial Quart Subscription Support #3818
base: main
Are you sure you want to change the base?
Changes from all commits
b4e4668
b61023d
b01e134
d3384b8
7e75108
f8b192c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,29 @@ | ||
import asyncio | ||
import warnings | ||
from collections.abc import AsyncGenerator, Mapping | ||
from collections.abc import AsyncGenerator, Mapping, Sequence | ||
from datetime import timedelta | ||
from json.decoder import JSONDecodeError | ||
from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast | ||
from typing_extensions import TypeGuard | ||
|
||
from quart import Request, Response, request | ||
from quart import Quart, Request, Response, request, websocket | ||
from quart.ctx import has_websocket_context | ||
from quart.views import View | ||
from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter | ||
from strawberry.http.exceptions import HTTPException | ||
from strawberry.http.async_base_view import ( | ||
AsyncBaseHTTPView, | ||
AsyncHTTPRequestAdapter, | ||
AsyncWebSocketAdapter, | ||
) | ||
from strawberry.http.exceptions import ( | ||
HTTPException, | ||
NonJsonMessageReceived, | ||
NonTextMessageReceived, | ||
WebSocketDisconnected, | ||
) | ||
from strawberry.http.ides import GraphQL_IDE | ||
from strawberry.http.types import FormData, HTTPMethod, QueryParams | ||
from strawberry.http.typevars import Context, RootValue | ||
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL | ||
|
||
if TYPE_CHECKING: | ||
from quart.typing import ResponseReturnValue | ||
|
@@ -46,6 +60,34 @@ async def get_form_data(self) -> FormData: | |
return FormData(files=files, form=form) | ||
|
||
|
||
class QuartWebSocketAdapter(AsyncWebSocketAdapter): | ||
def __init__(self, view: AsyncBaseHTTPView, request, ws) -> None: | ||
super().__init__(view) | ||
self.ws = websocket | ||
|
||
async def iter_json( | ||
self, *, ignore_parsing_errors: bool = False | ||
) -> AsyncGenerator[object, None]: | ||
while True: | ||
message = await self.ws.receive() | ||
if type(message) is bytes: | ||
raise NonTextMessageReceived | ||
try: | ||
yield self.view.decode_json(message) | ||
except JSONDecodeError as e: | ||
if not ignore_parsing_errors: | ||
raise NonJsonMessageReceived from e | ||
|
||
async def send_json(self, message: Mapping[str, object]) -> None: | ||
try: | ||
await self.ws.send(self.view.encode_json(message)) | ||
except asyncio.CancelledError as exc: | ||
raise WebSocketDisconnected from exc | ||
|
||
async def close(self, code: int, reason: str) -> None: | ||
await self.ws.close(code, reason=reason) | ||
|
||
|
||
class GraphQLView( | ||
AsyncBaseHTTPView[ | ||
Request, Response, Response, Request, Response, Context, RootValue | ||
|
@@ -55,17 +97,31 @@ class GraphQLView( | |
methods: ClassVar[list[str]] = ["GET", "POST"] | ||
allow_queries_via_get: bool = True | ||
request_adapter_class = QuartHTTPRequestAdapter | ||
websocket_adapter_class = QuartWebSocketAdapter | ||
|
||
def __init__( | ||
self, | ||
schema: "BaseSchema", | ||
graphiql: Optional[bool] = None, | ||
graphql_ide: Optional[GraphQL_IDE] = "graphiql", | ||
allow_queries_via_get: bool = True, | ||
keep_alive: bool = True, | ||
keep_alive_interval: float = 1, | ||
debug: bool = False, | ||
subscription_protocols: Sequence[str] = [ | ||
GRAPHQL_TRANSPORT_WS_PROTOCOL, | ||
GRAPHQL_WS_PROTOCOL, | ||
], | ||
connection_init_wait_timeout: timedelta = timedelta(minutes=1), | ||
multipart_uploads_enabled: bool = False, | ||
) -> None: | ||
self.schema = schema | ||
self.allow_queries_via_get = allow_queries_via_get | ||
self.keep_alive = keep_alive | ||
self.keep_alive_interval = keep_alive_interval | ||
self.debug = debug | ||
self.subscription_protocols = subscription_protocols | ||
self.connection_init_wait_timeout = connection_init_wait_timeout | ||
self.multipart_uploads_enabled = multipart_uploads_enabled | ||
|
||
if graphiql is not None: | ||
|
@@ -123,15 +179,53 @@ async def create_streaming_response( | |
) | ||
|
||
def is_websocket_request(self, request: Request) -> TypeGuard[Request]: | ||
return False | ||
if has_websocket_context(): | ||
return True | ||
treo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Check if the request is a WebSocket upgrade request | ||
connection = request.headers.get("Connection", "").lower() | ||
upgrade = request.headers.get("Upgrade", "").lower() | ||
|
||
return "upgrade" in connection and "websocket" in upgrade | ||
|
||
async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: | ||
raise NotImplementedError | ||
# Get the requested protocols | ||
protocols_header = websocket.headers.get("Sec-WebSocket-Protocol", "") | ||
if not protocols_header: | ||
return None | ||
|
||
# Find the first matching protocol | ||
requested_protocols = [p.strip() for p in protocols_header.split(",")] | ||
for protocol in requested_protocols: | ||
if protocol in self.subscription_protocols: | ||
return protocol | ||
|
||
return None | ||
|
||
async def create_websocket_response( | ||
self, request: Request, subprotocol: Optional[str] | ||
) -> Response: | ||
raise NotImplementedError | ||
await websocket.accept(subprotocol=subprotocol) | ||
# Return the current websocket context as the "response" | ||
return None | ||
|
||
@classmethod | ||
def register_route(cls, app: Quart, rule_name: str, path: str, **kwargs): | ||
"""Helper method to register both HTTP and WebSocket handlers for a given path. | ||
|
||
Args: | ||
app: The Quart application | ||
rule_name: The name of the rule | ||
path: The path to register the handlers for | ||
**kwargs: Parameters to pass to the GraphQLView constructor | ||
""" | ||
# Register both HTTP and WebSocket handler at the same path | ||
view_func = cls.as_view(rule_name, **kwargs) | ||
app.add_url_rule(path, view_func=view_func, methods=["GET", "POST"]) | ||
|
||
# Register the WebSocket handler using the same view function | ||
# Quart will handle routing based on the WebSocket upgrade header | ||
app.add_url_rule(path, view_func=view_func, methods=["GET"], websocket=True) | ||
Comment on lines
+212
to
+228
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I'd prefer if we left it up to the user to register the two URL rules. The examples in our docs for adding HTTP/WS/HTTP+WS URL rules would then end up looking more consistent. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is more of a convenience method, but if you think that just documenting it is good enough, then I'm fine to remove it. Code that doesn't exist doesn't have any bugs :D |
||
|
||
|
||
__all__ = ["GraphQLView"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally, we would adjust these types to include
quart.Websocket
and makecreate_websocket_response
returnwebsocket
etc. so that we end up with sound types.That being said, I recognize
quart
loves their global context vars, so we could go withNone
here too. In that caseget_context
might need an update, since it's currently just using the return value fromcreate_websocket_response
which isNone
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I wasn't quite sure how to approach this exactly because of the context vars, But once the tests are all passing, we can take a look at cleaning this up.