Skip to content

Commit

Permalink
feat: expose websocket_class to the other layers
Browse files Browse the repository at this point in the history
  • Loading branch information
kedod authored and provinzkraut committed Mar 2, 2024
1 parent 7d2335c commit b7b6582
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 2 deletions.
19 changes: 19 additions & 0 deletions docs/examples/websockets/custom_websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from litestar import Litestar, WebSocket, websocket_listener
from litestar.types.asgi_types import WebSocketMode


class CustomWebSocket(WebSocket):
async def receive_data(self, mode: WebSocketMode) -> str | bytes:
"""Return fixed response for every websocket message."""
await super().receive_data(mode=mode)
return "Fixed response"


@websocket_listener("/")
async def handler(data: str) -> str:
return data


app = Litestar([handler], websocket_class=CustomWebSocket)
22 changes: 22 additions & 0 deletions docs/usage/websockets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,25 @@ encapsulate more complex logic.

.. literalinclude:: /examples/websockets/listener_class_based_async.py
:language: python


Custom WebSocket
----------------

.. versionadded:: 2.7.0

Litestar supports custom ``websocket_class`` instances, which can be used to further configure the default :class:`WebSocket`.
The example below illustrates how to implement custom websocket class for the whole application.

.. dropdown:: Example of a custom websocket at the application level

.. literalinclude:: /examples/websockets/custom_websocket.py
:language: python

.. admonition:: Layered architecture

WebSocket classes are part of Litestar's layered architecture, which means you can
set a websocket class on every layer of the application. If you have set a webscoket
class on multiple layers, the layer closest to the route handler will take precedence.

You can read more about this in the :ref:`usage/applications:layered architecture` section
3 changes: 2 additions & 1 deletion litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def __init__(
self.state = config.state
self._static_files_config = config.static_files_config
self.template_engine = config.template_config.engine_instance if config.template_config else None
self.websocket_class = config.websocket_class or WebSocket
self.websocket_class: type[WebSocket] = config.websocket_class or WebSocket
self.debug = config.debug
self.pdb_on_exception: bool = config.pdb_on_exception
self.include_in_schema = include_in_schema
Expand Down Expand Up @@ -462,6 +462,7 @@ def __init__(
type_encoders=config.type_encoders,
type_decoders=config.type_decoders,
include_in_schema=config.include_in_schema,
websocket_class=self.websocket_class,
)

for route_handler in config.route_handlers:
Expand Down
6 changes: 6 additions & 0 deletions litestar/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


if TYPE_CHECKING:
from litestar.connection import WebSocket
from litestar.datastructures import CacheControlHeader, ETag
from litestar.dto import AbstractDTO
from litestar.openapi.spec import SecurityRequirement
Expand Down Expand Up @@ -70,6 +71,7 @@ class Controller:
"tags",
"type_encoders",
"type_decoders",
"websocket_class",
)

after_request: AfterRequestHookHandler | None
Expand Down Expand Up @@ -154,6 +156,10 @@ class Controller:
"""A mapping of types to callables that transform them into types supported for serialization."""
type_decoders: TypeDecodersSequence | None
"""A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization."""
websocket_class: type[WebSocket] | None
"""A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default websocket for all route
handlers under the controller.
"""

def __init__(self, owner: Router) -> None:
"""Initialize a controller.
Expand Down
14 changes: 14 additions & 0 deletions litestar/handlers/websocket_handlers/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class WebsocketListenerRouteHandler(WebsocketRouteHandler):
"connection_accept_handler": "Callback to accept a WebSocket connection. By default, calls WebSocket.accept",
"on_accept": "Callback invoked after a WebSocket connection has been accepted",
"on_disconnect": "Callback invoked after a WebSocket connection has been closed",
"weboscket_class": "WebSocket class",
"_connection_lifespan": None,
"_handle_receive": None,
"_handle_send": None,
Expand All @@ -86,6 +87,7 @@ def __init__(
return_dto: type[AbstractDTO] | None | EmptyType = Empty,
signature_namespace: Mapping[str, Any] | None = None,
type_encoders: TypeEncodersMap | None = None,
websocket_class: type[WebSocket] | None = None,
**kwargs: Any,
) -> None:
...
Expand All @@ -110,6 +112,7 @@ def __init__(
return_dto: type[AbstractDTO] | None | EmptyType = Empty,
signature_namespace: Mapping[str, Any] | None = None,
type_encoders: TypeEncodersMap | None = None,
websocket_class: type[WebSocket] | None = None,
**kwargs: Any,
) -> None:
...
Expand All @@ -134,6 +137,7 @@ def __init__(
return_dto: type[AbstractDTO] | None | EmptyType = Empty,
signature_namespace: Mapping[str, Any] | None = None,
type_encoders: TypeEncodersMap | None = None,
websocket_class: type[WebSocket] | None = None,
**kwargs: Any,
) -> None:
"""Initialize ``WebsocketRouteHandler``
Expand Down Expand Up @@ -168,6 +172,8 @@ def __init__(
modelling.
type_encoders: A mapping of types to callables that transform them into types supported for serialization.
**kwargs: Any additional kwarg - will be set in the opt dictionary.
websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's
default websocket class.
"""
if connection_lifespan and any([on_accept, on_disconnect, connection_accept_handler is not WebSocket.accept]):
raise ImproperlyConfiguredException(
Expand All @@ -185,6 +191,7 @@ def __init__(
self.on_accept = ensure_async_callable(on_accept) if on_accept else None
self.on_disconnect = ensure_async_callable(on_disconnect) if on_disconnect else None
self.type_encoders = type_encoders
self.websocket_class = websocket_class

listener_dependencies = dict(dependencies or {})

Expand All @@ -209,6 +216,7 @@ def __init__(
signature_namespace=signature_namespace,
dto=dto,
return_dto=return_dto,
websocket_class=websocket_class,
**kwargs,
)

Expand Down Expand Up @@ -346,6 +354,11 @@ class WebsocketListener(ABC):
"""
type_encoders: A mapping of types to callables that transform them into types supported for serialization.
"""
websocket_class: type[WebSocket] | None = None
"""
websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's
default websocket class.
"""

def __init__(self, owner: Router) -> None:
"""Initialize a WebsocketListener instance.
Expand All @@ -372,6 +385,7 @@ def to_handler(self) -> WebsocketListenerRouteHandler:
return_dto=self.return_dto,
signature_namespace=self.signature_namespace,
type_encoders=self.type_encoders,
websocket_class=self.websocket_class,
)(self.on_receive)
handler.owner = self._owner
return handler
Expand Down
18 changes: 18 additions & 0 deletions litestar/handlers/websocket_handlers/route_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING, Any, Mapping

from litestar.connection import WebSocket
from litestar.exceptions import ImproperlyConfiguredException
from litestar.handlers import BaseRouteHandler
from litestar.types.builtin_types import NoneType
Expand All @@ -28,6 +29,7 @@ def __init__(
name: str | None = None,
opt: dict[str, Any] | None = None,
signature_namespace: Mapping[str, Any] | None = None,
websocket_class: type[WebSocket] | None = None,
**kwargs: Any,
) -> None:
"""Initialize ``WebsocketRouteHandler``
Expand All @@ -46,7 +48,10 @@ def __init__(
signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling.
type_encoders: A mapping of types to callables that transform them into types supported for serialization.
**kwargs: Any additional kwarg - will be set in the opt dictionary.
websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's
default websocket class.
"""
self.websocket_class = websocket_class

super().__init__(
path=path,
Expand All @@ -60,6 +65,19 @@ def __init__(
**kwargs,
)

def resolve_websocket_class(self) -> type[WebSocket]:
"""Return the closest custom WebSocket class in the owner graph or the default Websocket class.
This method is memoized so the computation occurs only once.
Returns:
The default :class:`WebSocket <.connection.WebSocket>` class for the route handler.
"""
return next(
(layer.websocket_class for layer in reversed(self.ownership_layers) if layer.websocket_class is not None),
WebSocket,
)

def _validate_handler_function(self) -> None:
"""Validate the route handler function once it's set by inspecting its return annotations."""
super()._validate_handler_function()
Expand Down
6 changes: 6 additions & 0 deletions litestar/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


if TYPE_CHECKING:
from litestar.connection import WebSocket
from litestar.datastructures import CacheControlHeader, ETag
from litestar.dto import AbstractDTO
from litestar.openapi.spec import SecurityRequirement
Expand Down Expand Up @@ -76,6 +77,7 @@ class Router:
"tags",
"type_encoders",
"type_decoders",
"websocket_class",
)

def __init__(
Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(
tags: Sequence[str] | None = None,
type_encoders: TypeEncodersMap | None = None,
type_decoders: TypeDecodersSequence | None = None,
websocket_class: type[WebSocket] | None = None,
) -> None:
"""Initialize a ``Router``.
Expand Down Expand Up @@ -156,6 +159,8 @@ def __init__(
application.
type_encoders: A mapping of types to callables that transform them into types supported for serialization.
type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization.
websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default for
all route handlers, controllers and other routers associated with the router instance.
"""

self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
Expand Down Expand Up @@ -186,6 +191,7 @@ def __init__(
self.registered_route_handler_ids: set[int] = set()
self.type_encoders = dict(type_encoders) if type_encoders is not None else None
self.type_decoders = list(type_decoders) if type_decoders is not None else None
self.websocket_class = websocket_class

for route_handler in route_handlers or []:
self.register(value=route_handler)
Expand Down
5 changes: 4 additions & 1 deletion litestar/routes/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> N
Returns:
None
"""
websocket: WebSocket[Any, Any, Any] = scope["app"].websocket_class(scope=scope, receive=receive, send=send)

if not self.handler_parameter_model: # pragma: no cover
raise ImproperlyConfiguredException("handler parameter model not defined")

websocket: WebSocket[Any, Any, Any] = self.route_handler.resolve_websocket_class()(
scope=scope, receive=receive, send=send
)

if self.route_handler.resolve_guards():
await self.route_handler.authorize_connection(connection=websocket)

Expand Down
115 changes: 115 additions & 0 deletions tests/unit/test_websocket_class_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from typing import Type, Union

import pytest

from litestar import Controller, Litestar, Router, WebSocket
from litestar.handlers.websocket_handlers.listener import WebsocketListener, websocket_listener

RouterWebSocket: Type[WebSocket] = type("RouterWebSocket", (WebSocket,), {})
ControllerWebSocket: Type[WebSocket] = type("ControllerWebSocket", (WebSocket,), {})
AppWebSocket: Type[WebSocket] = type("AppWebSocket", (WebSocket,), {})
HandlerWebSocket: Type[WebSocket] = type("HandlerWebSocket", (WebSocket,), {})


@pytest.mark.parametrize(
"handler_websocket_class, controller_websocket_class, router_websocket_class, app_websocket_class, has_default_app_class, expected",
(
(HandlerWebSocket, ControllerWebSocket, RouterWebSocket, AppWebSocket, True, HandlerWebSocket),
(None, ControllerWebSocket, RouterWebSocket, AppWebSocket, True, ControllerWebSocket),
(None, None, RouterWebSocket, AppWebSocket, True, RouterWebSocket),
(None, None, None, AppWebSocket, True, AppWebSocket),
(None, None, None, None, True, WebSocket),
(None, None, None, None, False, WebSocket),
),
ids=(
"Custom class for all layers",
"Custom class for all above handler layer",
"Custom class for all above controller layer",
"Custom class for all above router layer",
"No custom class for layers",
"No default class in app",
),
)
def test_websocket_class_resolution_of_layers(
handler_websocket_class: Union[Type[WebSocket], None],
controller_websocket_class: Union[Type[WebSocket], None],
router_websocket_class: Union[Type[WebSocket], None],
app_websocket_class: Union[Type[WebSocket], None],
has_default_app_class: bool,
expected: Type[WebSocket],
) -> None:
class MyController(Controller):
@websocket_listener("/")
def handler(self, data: str) -> None:
return

if controller_websocket_class:
MyController.websocket_class = ControllerWebSocket

router = Router(path="/", route_handlers=[MyController])

if router_websocket_class:
router.websocket_class = router_websocket_class

app = Litestar(route_handlers=[router])

if app_websocket_class or not has_default_app_class:
app.websocket_class = app_websocket_class # type: ignore

route_handler = app.routes[0].route_handler # type: ignore

if handler_websocket_class:
route_handler.websocket_class = handler_websocket_class # type: ignore

websocket_class = route_handler.resolve_websocket_class() # type: ignore
assert websocket_class is expected


@pytest.mark.parametrize(
"handler_websocket_class, router_websocket_class, app_websocket_class, has_default_app_class, expected",
(
(HandlerWebSocket, RouterWebSocket, AppWebSocket, True, HandlerWebSocket),
(None, RouterWebSocket, AppWebSocket, True, RouterWebSocket),
(None, None, AppWebSocket, True, AppWebSocket),
(None, None, None, True, WebSocket),
(None, None, None, False, WebSocket),
),
ids=(
"Custom class for all layers",
"Custom class for all above handler layer",
"Custom class for all above router layer",
"No custom class for layers",
"No default class in app",
),
)
def test_listener_websocket_class_resolution_of_layers(
handler_websocket_class: Union[Type[WebSocket], None],
router_websocket_class: Union[Type[WebSocket], None],
app_websocket_class: Union[Type[WebSocket], None],
has_default_app_class: bool,
expected: Type[WebSocket],
) -> None:
class Handler(WebsocketListener):
path = "/"
websocket_class = handler_websocket_class

def on_receive(self, data: str) -> str: # pyright: ignore
return data

router = Router(path="/", route_handlers=[Handler])

if router_websocket_class:
router.websocket_class = router_websocket_class

app = Litestar(route_handlers=[router])

if app_websocket_class or not has_default_app_class:
app.websocket_class = app_websocket_class # type: ignore

route_handler = app.routes[0].route_handler # type: ignore

if handler_websocket_class:
route_handler.websocket_class = handler_websocket_class # type: ignore

websocket_class = route_handler.resolve_websocket_class() # type: ignore
assert websocket_class is expected

0 comments on commit b7b6582

Please sign in to comment.