From dabf93753aa4851e28ac2c43f459c83e3cd55021 Mon Sep 17 00:00:00 2001
From: David Brochart
Date: Fri, 14 Jan 2022 17:59:01 +0100
Subject: [PATCH] Support websocket subprotocols

---
Review notes (free text after the three-dash separator; it is outside the
commit message and outside every hunk):

* Formatting: this copy of the patch had its line breaks collapsed. The
  line structure below was rebuilt from the line counts in each hunk
  header. Indentation and the position of blank context lines are
  inferred - verify against the target tree before applying.
* routes.py, last hunk: the header declares 10 old lines but only 9 are
  visible in the recovered text; a trailing blank context line is assumed.
* server.py, listen(): `channel` is bound only for "shell", "control" and
  "iopub". Any other channel_name reaches get_zmq_parts(channel) with the
  local unbound. Suggest a final `else: raise ValueError(channel_name)`.
* server.py, listen(): the iopub branch iterates self.sessions.values()
  and awaits inside the loop, while serve() inserts into and deletes from
  the same dict. A client connecting or disconnecting during a send can
  raise "RuntimeError: dictionary changed size during iteration". Suggest
  iterating over list(self.sessions.values()).
* server.py, listen(): in the iopub broadcast only the "0.0.1" send is
  wrapped in try/except. The legacy send_json_or_bytes() call is not, so
  an exception from one client propagates out of the listen("iopub") task.
* server.py, serve(): `del self.sessions[session_id]` is skipped when
  listen_web() raises anything other than WebSocketDisconnect. Suggest
  try/finally around the listen_web() call.
* Names in scope: server.py uses receive_json_or_bytes, send_json_or_bytes
  and send_message, and message.py uses Optional, but no hunk adds an
  import for them. Presumably they are already in scope - TODO confirm.

 .../fps_kernels/kernel_server/connect.py      |  21 +++-
 .../fps_kernels/kernel_server/message.py      |  39 ++++--
 .../fps_kernels/kernel_server/server.py       | 114 +++++++++++-------
 plugins/kernels/fps_kernels/routes.py         |   8 +-
 4 files changed, 123 insertions(+), 59 deletions(-)

diff --git a/plugins/kernels/fps_kernels/kernel_server/connect.py b/plugins/kernels/fps_kernels/kernel_server/connect.py
index 53325979..d6f04e9b 100644
--- a/plugins/kernels/fps_kernels/kernel_server/connect.py
+++ b/plugins/kernels/fps_kernels/kernel_server/connect.py
@@ -4,12 +4,14 @@
 import json
 import asyncio
 import uuid
-from typing import Dict, Tuple, Union
+from typing import Dict, Tuple, Union, Optional
 
 import zmq
 import zmq.asyncio
 from zmq.sugar.socket import Socket
 
+from fastapi import WebSocket
+
 
 channel_socket_types = {
     "hb": zmq.REQ,
@@ -101,3 +103,20 @@ def connect_channel(channel_name: str, cfg: cfg_t) -> Socket:
     if channel_name == "iopub":
         sock.setsockopt(zmq.SUBSCRIBE, b"")
     return sock
+
+
+class AcceptedWebSocket:
+    _websocket: WebSocket
+    _accepted_subprotocol: Optional[str]
+
+    def __init__(self, websocket, accepted_subprotocol):
+        self._websocket = websocket
+        self._accepted_subprotocol = accepted_subprotocol
+
+    @property
+    def websocket(self):
+        return self._websocket
+
+    @property
+    def accepted_subprotocol(self):
+        return self._accepted_subprotocol
diff --git a/plugins/kernels/fps_kernels/kernel_server/message.py b/plugins/kernels/fps_kernels/kernel_server/message.py
index f40b0a29..e07cf84f 100644
--- a/plugins/kernels/fps_kernels/kernel_server/message.py
+++ b/plugins/kernels/fps_kernels/kernel_server/message.py
@@ -69,13 +69,18 @@ def serialize(msg: Dict[str, Any], key: str) -> List[bytes]:
     return to_send
 
 
-def deserialize(msg_list: List[bytes]) -> Dict[str, Any]:
+def deserialize(
+    msg_list: List[bytes], parent_header: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
     message: Dict[str, Any] = {}
     header = unpack(msg_list[1])
     message["header"] = header
     message["msg_id"] = header["msg_id"]
     message["msg_type"] = header["msg_type"]
-    message["parent_header"] = unpack(msg_list[2])
+    if parent_header:
+        message["parent_header"] = parent_header
+    else:
+        message["parent_header"] = unpack(msg_list[2])
     message["metadata"] = unpack(msg_list[3])
     message["content"] = unpack(msg_list[4])
     message["buffers"] = [memoryview(b) for b in msg_list[5:]]
@@ -104,8 +109,8 @@ def send_raw_message(parts: List[bytes], sock: Socket, key: str) -> None:
 
 def get_channel_parts(msg: bytes) -> Tuple[str, List[bytes]]:
     layout_len = int.from_bytes(msg[:2], "little")
-    layout = json.loads(msg[2:2 + layout_len])
-    parts: List[bytes] = list(get_parts(msg[2 + layout_len:], layout["offsets"]))
+    layout = json.loads(msg[2 : 2 + layout_len])
+    parts: List[bytes] = list(get_parts(msg[2 + layout_len :], layout["offsets"]))
     return layout["channel"], parts
 
 
@@ -129,26 +134,38 @@ async def receive_message(
     return None
 
 
-def get_bin_msg(channel: str, parts: List[bytes]) -> List[bytes]:
+async def get_zmq_parts(socket: Socket) -> List[bytes]:
+    parts = await socket.recv_multipart()
     idents, parts = feed_identities(parts)
+    return parts
+
+
+def get_msg_from_parts(
+    parts: List[bytes], parent_header: Optional[Dict[str, Any]] = None
+) -> Dict[str, Any]:
+    return deserialize(parts, parent_header=parent_header)
+
+
+def get_bin_msg_from_parts(channel: str, parts: List[bytes]) -> List[bytes]:
     offsets = []
     curr_sum = 0
     for part in parts[1:]:
         length = len(part)
         offsets.append(length + curr_sum)
         curr_sum += length
-    layout = json.dumps({
-        "channel": channel,
-        "offsets": offsets,
-    }).encode("utf-8")
+    layout = json.dumps(
+        {
+            "channel": channel,
+            "offsets": offsets,
+        }
+    ).encode("utf-8")
     layout_length = len(layout).to_bytes(2, byteorder="little")
     bin_msg = [layout_length, layout] + parts[1:]
     return bin_msg
 
 
 def get_parent_header(parts: List[bytes]) -> Dict[str, Any]:
-    idents, msg_list = feed_identities(parts)
-    return unpack(msg_list[2])
+    return unpack(parts[2])
 
 
 def utcnow() -> datetime:
diff --git a/plugins/kernels/fps_kernels/kernel_server/server.py b/plugins/kernels/fps_kernels/kernel_server/server.py
index e01d8f27..656bd6d0 100644
--- a/plugins/kernels/fps_kernels/kernel_server/server.py
+++ b/plugins/kernels/fps_kernels/kernel_server/server.py
@@ -5,7 +5,7 @@
 from datetime import datetime
 from typing import Iterable, Optional, List, Dict, cast
 
-from fastapi import WebSocket, WebSocketDisconnect  # type: ignore
+from fastapi import WebSocketDisconnect  # type: ignore
 from starlette.websockets import WebSocketState
 
 from .connect import (
@@ -14,6 +14,7 @@
     launch_kernel,
     connect_channel,
     cfg_t,
+    AcceptedWebSocket,
 )  # type: ignore
 from .message import (
     receive_message,
@@ -24,7 +25,9 @@
     from_binary,
     get_channel_parts,
     get_parent_header,
-    get_bin_msg,
+    get_zmq_parts,
+    get_bin_msg_from_parts,
+    get_msg_from_parts,
 )  # type: ignore
 
 
@@ -46,7 +49,7 @@ def __init__(
         self.connection_file = connection_file
         self.write_connection_file = write_connection_file
         self.channel_tasks: List[asyncio.Task] = []
-        self.sessions: Dict[str, WebSocket] = {}
+        self.sessions: Dict[str, AcceptedWebSocket] = {}
         # blocked messages and allowed messages are mutually exclusive
         self.blocked_messages: List[str] = []
         self.allowed_messages: Optional[
@@ -107,9 +110,9 @@ async def start(self) -> None:
         self.iopub_channel = connect_channel("iopub", self.connection_cfg)
         await self._wait_for_ready()
         self.channel_tasks += [
-            asyncio.create_task(self.listen_shell()),
-            asyncio.create_task(self.listen_control()),
-            asyncio.create_task(self.listen_iopub()),
+            asyncio.create_task(self.listen("shell")),
+            asyncio.create_task(self.listen("control")),
+            asyncio.create_task(self.listen("iopub")),
         ]
 
     async def stop(self) -> None:
@@ -133,55 +136,78 @@ async def restart(self) -> None:
         self.setup_connection_file()
         await self.start()
 
-    async def serve(self, websocket: WebSocket, session_id: str):
+    async def serve(self, websocket: AcceptedWebSocket, session_id: str):
         self.sessions[session_id] = websocket
         await self.listen_web(websocket)
         del self.sessions[session_id]
 
-    async def listen_web(self, websocket: WebSocket):
+    async def listen_web(self, websocket: AcceptedWebSocket):
         try:
-            while True:
-                msg = await websocket.receive_bytes()
-                # FIXME: add back message filtering
-                channel, parts = get_channel_parts(msg)
-                if channel == "shell":
-                    send_raw_message(parts, self.shell_channel, self.key)
-                elif channel == "control":
-                    send_raw_message(parts, self.control_channel, self.key)
+            if not websocket.accepted_subprotocol:
+                while True:
+                    msg = await receive_json_or_bytes(websocket.websocket)
+                    msg_type = msg["header"]["msg_type"]
+                    if (msg_type in self.blocked_messages) or (
+                        self.allowed_messages is not None
+                        and msg_type not in self.allowed_messages
+                    ):
+                        continue
+                    channel = msg.pop("channel")
+                    if channel == "shell":
+                        send_message(msg, self.shell_channel, self.key)
+                    elif channel == "control":
+                        send_message(msg, self.control_channel, self.key)
+            elif websocket.accepted_subprotocol == "0.0.1":
+                while True:
+                    msg = await websocket.websocket.receive_bytes()
+                    # FIXME: add back message filtering
+                    channel, parts = get_channel_parts(msg)
+                    if channel == "shell":
+                        send_raw_message(parts, self.shell_channel, self.key)
+                    elif channel == "control":
+                        send_raw_message(parts, self.control_channel, self.key)
         except WebSocketDisconnect:
             pass
 
-    async def listen_shell(self):
-        while True:
-            parts = await self.shell_channel.recv_multipart()
-            parent_header = get_parent_header(parts)
-            session = parent_header["session"]
-            if session in self.sessions:
-                websocket = self.sessions[session]
-                bin_msg = get_bin_msg("shell", parts)
-                await websocket.send_bytes(bin_msg)
+    async def listen(self, channel_name: str):
+        if channel_name == "shell":
+            channel = self.shell_channel
+        elif channel_name == "control":
+            channel = self.control_channel
+        elif channel_name == "iopub":
+            channel = self.iopub_channel
 
-    async def listen_control(self):
         while True:
-            parts = await self.control_channel.recv_multipart()
+            parts = await get_zmq_parts(channel)
             parent_header = get_parent_header(parts)
-            session = parent_header["session"]
-            if session in self.sessions:
-                websocket = self.sessions[session]
-                bin_msg = get_bin_msg("control", parts)
-                await websocket.send_bytes(bin_msg)
-
-    async def listen_iopub(self):
-        while True:
-            parts = await self.iopub_channel.recv_multipart()
-            bin_msg = get_bin_msg("iopub", parts)
-            for websocket in self.sessions.values():
-                try:
-                    await websocket.send_bytes(bin_msg)
-                except Exception:
-                    pass
-                # FIXME: add back last_activity update
-                # or replace it with control channel retrieving
+            if channel == self.iopub_channel:
+                # broadcast to all web clients
+                for websocket in self.sessions.values():
+                    if not websocket.accepted_subprotocol:
+                        # default, "legacy" protocol
+                        msg = get_msg_from_parts(parts, parent_header=parent_header)
+                        msg["channel"] = channel_name
+                        await send_json_or_bytes(websocket.websocket, msg)
+                    elif websocket.accepted_subprotocol == "0.0.1":
+                        bin_msg = get_bin_msg_from_parts(channel_name, parts)
+                        try:
+                            await websocket.websocket.send_bytes(bin_msg)
+                        except Exception:
+                            pass
+                        # FIXME: add back last_activity update
+                        # or should we request it from the control channel?
+            else:
+                session = parent_header["session"]
+                if session in self.sessions:
+                    websocket = self.sessions[session]
+                    if not websocket.accepted_subprotocol:
+                        # default, "legacy" protocol
+                        msg = get_msg_from_parts(parts, parent_header=parent_header)
+                        msg["channel"] = channel_name
+                        await send_json_or_bytes(websocket.websocket, msg)
+                    elif websocket.accepted_subprotocol == "0.0.1":
+                        bin_msg = get_bin_msg_from_parts(channel_name, parts)
+                        await websocket.websocket.send_bytes(bin_msg)
 
     async def _wait_for_ready(self):
         while True:
diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py
index 808ea27d..bf742e6c 100644
--- a/plugins/kernels/fps_kernels/routes.py
+++ b/plugins/kernels/fps_kernels/routes.py
@@ -15,7 +15,7 @@
 from fps_auth.config import get_auth_config  # type: ignore
 from fps_lab.config import get_lab_config  # type: ignore
 
-from .kernel_server.server import KernelServer, kernels  # type: ignore
+from .kernel_server.server import AcceptedWebSocket, KernelServer, kernels  # type: ignore
 from .models import Session
 
 router = APIRouter()
@@ -202,10 +202,12 @@ async def kernel_channels(
         if user:
             accept_websocket = True
     if accept_websocket:
-        await websocket.accept()
+        subprotocol = "0.0.1" if "0.0.1" in websocket["subprotocols"] else None
+        await websocket.accept(subprotocol=subprotocol)
+        accepted_websocket = AcceptedWebSocket(websocket, subprotocol)
         if kernel_id in kernels:
             kernel_server = kernels[kernel_id]["server"]
-            await kernel_server.serve(accepted_websocket, session_id)
+            await kernel_server.serve(accepted_websocket, session_id)
     else:
         await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
 