From 78d1a1aa82c7f3e26dda59ed9cc3f47255b1ed5a Mon Sep 17 00:00:00 2001 From: kedod Date: Sat, 2 Mar 2024 15:05:23 +0100 Subject: [PATCH] feat: expose websocket_class (#3152) feat: expose websocket_class to the other layers --- docs/examples/websockets/custom_websocket.py | 19 +++ docs/usage/websockets.rst | 22 ++++ litestar/app.py | 3 +- litestar/controller.py | 7 +- .../handlers/websocket_handlers/listener.py | 14 +++ .../websocket_handlers/route_handler.py | 18 +++ litestar/router.py | 7 +- litestar/routes/websocket.py | 5 +- tests/unit/test_websocket_class_resolution.py | 115 ++++++++++++++++++ 9 files changed, 206 insertions(+), 4 deletions(-) create mode 100644 docs/examples/websockets/custom_websocket.py create mode 100644 tests/unit/test_websocket_class_resolution.py diff --git a/docs/examples/websockets/custom_websocket.py b/docs/examples/websockets/custom_websocket.py new file mode 100644 index 0000000000..954a3bdee0 --- /dev/null +++ b/docs/examples/websockets/custom_websocket.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from litestar import Litestar, WebSocket, websocket_listener +from litestar.types.asgi_types import WebSocketMode + + +class CustomWebSocket(WebSocket): + async def receive_data(self, mode: WebSocketMode) -> str | bytes: + """Return fixed response for every websocket message.""" + await super().receive_data(mode=mode) + return "Fixed response" + + +@websocket_listener("/") +async def handler(data: str) -> str: + return data + + +app = Litestar([handler], websocket_class=CustomWebSocket) diff --git a/docs/usage/websockets.rst b/docs/usage/websockets.rst index cc43904f17..6cbbc8c267 100644 --- a/docs/usage/websockets.rst +++ b/docs/usage/websockets.rst @@ -249,3 +249,25 @@ encapsulate more complex logic. .. literalinclude:: /examples/websockets/listener_class_based_async.py :language: python + + +Custom WebSocket +---------------- + +.. versionadded:: 2.7.0 + +Litestar supports custom ``websocket_class`` instances, which can be used to further configure the default :class:`WebSocket`. +The example below illustrates how to implement custom websocket class for the whole application. + +.. dropdown:: Example of a custom websocket at the application level + + .. literalinclude:: /examples/websockets/custom_websocket.py + :language: python + +.. admonition:: Layered architecture + + WebSocket classes are part of Litestar's layered architecture, which means you can + set a websocket class on every layer of the application. If you have set a webscoket + class on multiple layers, the layer closest to the route handler will take precedence. + + You can read more about this in the :ref:`usage/applications:layered architecture` section diff --git a/litestar/app.py b/litestar/app.py index 233627da54..b9089f78fa 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -418,7 +418,7 @@ def __init__( self.state = config.state self._static_files_config = config.static_files_config self.template_engine = config.template_config.engine_instance if config.template_config else None - self.websocket_class = config.websocket_class or WebSocket + self.websocket_class: type[WebSocket] = config.websocket_class or WebSocket self.debug = config.debug self.pdb_on_exception: bool = config.pdb_on_exception self.include_in_schema = include_in_schema @@ -463,6 +463,7 @@ def __init__( type_encoders=config.type_encoders, type_decoders=config.type_decoders, include_in_schema=config.include_in_schema, + websocket_class=self.websocket_class, ) for route_handler in config.route_handlers: diff --git a/litestar/controller.py b/litestar/controller.py index 33f9da9591..d2221ea90b 100644 --- a/litestar/controller.py +++ b/litestar/controller.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: - from litestar.connection import Request + from litestar.connection import Request, WebSocket from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.openapi.spec import SecurityRequirement @@ -72,6 +72,7 @@ class Controller: "tags", "type_encoders", "type_decoders", + "websocket_class", ) after_request: AfterRequestHookHandler | None @@ -160,6 +161,10 @@ class Controller: """A mapping of types to callables that transform them into types supported for serialization.""" type_decoders: TypeDecodersSequence | None """A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization.""" + websocket_class: type[WebSocket] | None + """A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default websocket for all route + handlers under the controller. + """ def __init__(self, owner: Router) -> None: """Initialize a controller. diff --git a/litestar/handlers/websocket_handlers/listener.py b/litestar/handlers/websocket_handlers/listener.py index 4d6bc99e81..7133ddba99 100644 --- a/litestar/handlers/websocket_handlers/listener.py +++ b/litestar/handlers/websocket_handlers/listener.py @@ -61,6 +61,7 @@ class WebsocketListenerRouteHandler(WebsocketRouteHandler): "connection_accept_handler": "Callback to accept a WebSocket connection. By default, calls WebSocket.accept", "on_accept": "Callback invoked after a WebSocket connection has been accepted", "on_disconnect": "Callback invoked after a WebSocket connection has been closed", + "weboscket_class": "WebSocket class", "_connection_lifespan": None, "_handle_receive": None, "_handle_send": None, @@ -86,6 +87,7 @@ def __init__( return_dto: type[AbstractDTO] | None | EmptyType = Empty, signature_namespace: Mapping[str, Any] | None = None, type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, **kwargs: Any, ) -> None: ... @@ -110,6 +112,7 @@ def __init__( return_dto: type[AbstractDTO] | None | EmptyType = Empty, signature_namespace: Mapping[str, Any] | None = None, type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, **kwargs: Any, ) -> None: ... @@ -134,6 +137,7 @@ def __init__( return_dto: type[AbstractDTO] | None | EmptyType = Empty, signature_namespace: Mapping[str, Any] | None = None, type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, **kwargs: Any, ) -> None: """Initialize ``WebsocketRouteHandler`` @@ -168,6 +172,8 @@ def __init__( modelling. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. """ if connection_lifespan and any([on_accept, on_disconnect, connection_accept_handler is not WebSocket.accept]): raise ImproperlyConfiguredException( @@ -185,6 +191,7 @@ def __init__( self.on_accept = ensure_async_callable(on_accept) if on_accept else None self.on_disconnect = ensure_async_callable(on_disconnect) if on_disconnect else None self.type_encoders = type_encoders + self.websocket_class = websocket_class listener_dependencies = dict(dependencies or {}) @@ -209,6 +216,7 @@ def __init__( signature_namespace=signature_namespace, dto=dto, return_dto=return_dto, + websocket_class=websocket_class, **kwargs, ) @@ -346,6 +354,11 @@ class WebsocketListener(ABC): """ type_encoders: A mapping of types to callables that transform them into types supported for serialization. """ + websocket_class: type[WebSocket] | None = None + """ + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + """ def __init__(self, owner: Router) -> None: """Initialize a WebsocketListener instance. @@ -372,6 +385,7 @@ def to_handler(self) -> WebsocketListenerRouteHandler: return_dto=self.return_dto, signature_namespace=self.signature_namespace, type_encoders=self.type_encoders, + websocket_class=self.websocket_class, )(self.on_receive) handler.owner = self._owner return handler diff --git a/litestar/handlers/websocket_handlers/route_handler.py b/litestar/handlers/websocket_handlers/route_handler.py index 850cf59c3c..edb49c3030 100644 --- a/litestar/handlers/websocket_handlers/route_handler.py +++ b/litestar/handlers/websocket_handlers/route_handler.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Mapping +from litestar.connection import WebSocket from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers import BaseRouteHandler from litestar.types.builtin_types import NoneType @@ -28,6 +29,7 @@ def __init__( name: str | None = None, opt: dict[str, Any] | None = None, signature_namespace: Mapping[str, Any] | None = None, + websocket_class: type[WebSocket] | None = None, **kwargs: Any, ) -> None: """Initialize ``WebsocketRouteHandler`` @@ -46,7 +48,10 @@ def __init__( signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. type_encoders: A mapping of types to callables that transform them into types supported for serialization. **kwargs: Any additional kwarg - will be set in the opt dictionary. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. """ + self.websocket_class = websocket_class super().__init__( path=path, @@ -60,6 +65,19 @@ def __init__( **kwargs, ) + def resolve_websocket_class(self) -> type[WebSocket]: + """Return the closest custom WebSocket class in the owner graph or the default Websocket class. + + This method is memoized so the computation occurs only once. + + Returns: + The default :class:`WebSocket <.connection.WebSocket>` class for the route handler. + """ + return next( + (layer.websocket_class for layer in reversed(self.ownership_layers) if layer.websocket_class is not None), + WebSocket, + ) + def _validate_handler_function(self) -> None: """Validate the route handler function once it's set by inspecting its return annotations.""" super()._validate_handler_function() diff --git a/litestar/router.py b/litestar/router.py index 75f5c6e876..e679b46a36 100644 --- a/litestar/router.py +++ b/litestar/router.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: - from litestar.connection import Request + from litestar.connection import Request, WebSocket from litestar.datastructures import CacheControlHeader, ETag from litestar.dto import AbstractDTO from litestar.openapi.spec import SecurityRequirement @@ -78,6 +78,7 @@ class Router: "tags", "type_encoders", "type_decoders", + "websocket_class", ) def __init__( @@ -109,6 +110,7 @@ def __init__( tags: Sequence[str] | None = None, type_encoders: TypeEncodersMap | None = None, type_decoders: TypeDecodersSequence | None = None, + websocket_class: type[WebSocket] | None = None, ) -> None: """Initialize a ``Router``. @@ -161,6 +163,8 @@ def __init__( application. type_encoders: A mapping of types to callables that transform them into types supported for serialization. type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default for + all route handlers, controllers and other routers associated with the router instance. """ self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore @@ -192,6 +196,7 @@ def __init__( self.registered_route_handler_ids: set[int] = set() self.type_encoders = dict(type_encoders) if type_encoders is not None else None self.type_decoders = list(type_decoders) if type_decoders is not None else None + self.websocket_class = websocket_class for route_handler in route_handlers or []: self.register(value=route_handler) diff --git a/litestar/routes/websocket.py b/litestar/routes/websocket.py index 9b309fe107..ebf4959d46 100644 --- a/litestar/routes/websocket.py +++ b/litestar/routes/websocket.py @@ -54,11 +54,14 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N Returns: None """ - websocket: WebSocket[Any, Any, Any] = scope["app"].websocket_class(scope=scope, receive=receive, send=send) if not self.handler_parameter_model: # pragma: no cover raise ImproperlyConfiguredException("handler parameter model not defined") + websocket: WebSocket[Any, Any, Any] = self.route_handler.resolve_websocket_class()( + scope=scope, receive=receive, send=send + ) + if self.route_handler.resolve_guards(): await self.route_handler.authorize_connection(connection=websocket) diff --git a/tests/unit/test_websocket_class_resolution.py b/tests/unit/test_websocket_class_resolution.py new file mode 100644 index 0000000000..560e0bb651 --- /dev/null +++ b/tests/unit/test_websocket_class_resolution.py @@ -0,0 +1,115 @@ +from typing import Type, Union + +import pytest + +from litestar import Controller, Litestar, Router, WebSocket +from litestar.handlers.websocket_handlers.listener import WebsocketListener, websocket_listener + +RouterWebSocket: Type[WebSocket] = type("RouterWebSocket", (WebSocket,), {}) +ControllerWebSocket: Type[WebSocket] = type("ControllerWebSocket", (WebSocket,), {}) +AppWebSocket: Type[WebSocket] = type("AppWebSocket", (WebSocket,), {}) +HandlerWebSocket: Type[WebSocket] = type("HandlerWebSocket", (WebSocket,), {}) + + +@pytest.mark.parametrize( + "handler_websocket_class, controller_websocket_class, router_websocket_class, app_websocket_class, has_default_app_class, expected", + ( + (HandlerWebSocket, ControllerWebSocket, RouterWebSocket, AppWebSocket, True, HandlerWebSocket), + (None, ControllerWebSocket, RouterWebSocket, AppWebSocket, True, ControllerWebSocket), + (None, None, RouterWebSocket, AppWebSocket, True, RouterWebSocket), + (None, None, None, AppWebSocket, True, AppWebSocket), + (None, None, None, None, True, WebSocket), + (None, None, None, None, False, WebSocket), + ), + ids=( + "Custom class for all layers", + "Custom class for all above handler layer", + "Custom class for all above controller layer", + "Custom class for all above router layer", + "No custom class for layers", + "No default class in app", + ), +) +def test_websocket_class_resolution_of_layers( + handler_websocket_class: Union[Type[WebSocket], None], + controller_websocket_class: Union[Type[WebSocket], None], + router_websocket_class: Union[Type[WebSocket], None], + app_websocket_class: Union[Type[WebSocket], None], + has_default_app_class: bool, + expected: Type[WebSocket], +) -> None: + class MyController(Controller): + @websocket_listener("/") + def handler(self, data: str) -> None: + return + + if controller_websocket_class: + MyController.websocket_class = ControllerWebSocket + + router = Router(path="/", route_handlers=[MyController]) + + if router_websocket_class: + router.websocket_class = router_websocket_class + + app = Litestar(route_handlers=[router]) + + if app_websocket_class or not has_default_app_class: + app.websocket_class = app_websocket_class # type: ignore + + route_handler = app.routes[0].route_handler # type: ignore + + if handler_websocket_class: + route_handler.websocket_class = handler_websocket_class # type: ignore + + websocket_class = route_handler.resolve_websocket_class() # type: ignore + assert websocket_class is expected + + +@pytest.mark.parametrize( + "handler_websocket_class, router_websocket_class, app_websocket_class, has_default_app_class, expected", + ( + (HandlerWebSocket, RouterWebSocket, AppWebSocket, True, HandlerWebSocket), + (None, RouterWebSocket, AppWebSocket, True, RouterWebSocket), + (None, None, AppWebSocket, True, AppWebSocket), + (None, None, None, True, WebSocket), + (None, None, None, False, WebSocket), + ), + ids=( + "Custom class for all layers", + "Custom class for all above handler layer", + "Custom class for all above router layer", + "No custom class for layers", + "No default class in app", + ), +) +def test_listener_websocket_class_resolution_of_layers( + handler_websocket_class: Union[Type[WebSocket], None], + router_websocket_class: Union[Type[WebSocket], None], + app_websocket_class: Union[Type[WebSocket], None], + has_default_app_class: bool, + expected: Type[WebSocket], +) -> None: + class Handler(WebsocketListener): + path = "/" + websocket_class = handler_websocket_class + + def on_receive(self, data: str) -> str: # pyright: ignore + return data + + router = Router(path="/", route_handlers=[Handler]) + + if router_websocket_class: + router.websocket_class = router_websocket_class + + app = Litestar(route_handlers=[router]) + + if app_websocket_class or not has_default_app_class: + app.websocket_class = app_websocket_class # type: ignore + + route_handler = app.routes[0].route_handler # type: ignore + + if handler_websocket_class: + route_handler.websocket_class = handler_websocket_class # type: ignore + + websocket_class = route_handler.resolve_websocket_class() # type: ignore + assert websocket_class is expected