Skip to content

Commit

Permalink
Support websocket subprotocols
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 18, 2022
1 parent 3262df7 commit dabf937
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 59 deletions.
21 changes: 20 additions & 1 deletion plugins/kernels/fps_kernels/kernel_server/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import json
import asyncio
import uuid
from typing import Dict, Tuple, Union
from typing import Dict, Tuple, Union, Optional

import zmq
import zmq.asyncio
from zmq.sugar.socket import Socket

from fastapi import WebSocket


channel_socket_types = {
"hb": zmq.REQ,
Expand Down Expand Up @@ -101,3 +103,20 @@ def connect_channel(channel_name: str, cfg: cfg_t) -> Socket:
if channel_name == "iopub":
sock.setsockopt(zmq.SUBSCRIBE, b"")
return sock


class AcceptedWebSocket:
_websocket: WebSocket
_accepted_subprotocol: Optional[str]

def __init__(self, websocket, accepted_subprotocol):
self._websocket = websocket
self._accepted_subprotocol = accepted_subprotocol

@property
def websocket(self):
return self._websocket

@property
def accepted_subprotocol(self):
return self._accepted_subprotocol
39 changes: 28 additions & 11 deletions plugins/kernels/fps_kernels/kernel_server/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,18 @@ def serialize(msg: Dict[str, Any], key: str) -> List[bytes]:
return to_send


def deserialize(msg_list: List[bytes]) -> Dict[str, Any]:
def deserialize(
msg_list: List[bytes], parent_header: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
message: Dict[str, Any] = {}
header = unpack(msg_list[1])
message["header"] = header
message["msg_id"] = header["msg_id"]
message["msg_type"] = header["msg_type"]
message["parent_header"] = unpack(msg_list[2])
if parent_header:
message["parent_header"] = parent_header
else:
message["parent_header"] = unpack(msg_list[2])
message["metadata"] = unpack(msg_list[3])
message["content"] = unpack(msg_list[4])
message["buffers"] = [memoryview(b) for b in msg_list[5:]]
Expand Down Expand Up @@ -104,8 +109,8 @@ def send_raw_message(parts: List[bytes], sock: Socket, key: str) -> None:

def get_channel_parts(msg: bytes) -> Tuple[str, List[bytes]]:
layout_len = int.from_bytes(msg[:2], "little")
layout = json.loads(msg[2:2 + layout_len])
parts: List[bytes] = list(get_parts(msg[2 + layout_len:], layout["offsets"]))
layout = json.loads(msg[2 : 2 + layout_len])
parts: List[bytes] = list(get_parts(msg[2 + layout_len :], layout["offsets"]))
return layout["channel"], parts


Expand All @@ -129,26 +134,38 @@ async def receive_message(
return None


def get_bin_msg(channel: str, parts: List[bytes]) -> List[bytes]:
async def get_zmq_parts(socket: Socket) -> List[bytes]:
parts = await socket.recv_multipart()
idents, parts = feed_identities(parts)
return parts


def get_msg_from_parts(
parts: List[bytes], parent_header: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
return deserialize(parts, parent_header=parent_header)


def get_bin_msg_from_parts(channel: str, parts: List[bytes]) -> List[bytes]:
offsets = []
curr_sum = 0
for part in parts[1:]:
length = len(part)
offsets.append(length + curr_sum)
curr_sum += length
layout = json.dumps({
"channel": channel,
"offsets": offsets,
}).encode("utf-8")
layout = json.dumps(
{
"channel": channel,
"offsets": offsets,
}
).encode("utf-8")
layout_length = len(layout).to_bytes(2, byteorder="little")
bin_msg = [layout_length, layout] + parts[1:]
return bin_msg


def get_parent_header(parts: List[bytes]) -> Dict[str, Any]:
idents, msg_list = feed_identities(parts)
return unpack(msg_list[2])
return unpack(parts[2])


def utcnow() -> datetime:
Expand Down
114 changes: 70 additions & 44 deletions plugins/kernels/fps_kernels/kernel_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime
from typing import Iterable, Optional, List, Dict, cast

from fastapi import WebSocket, WebSocketDisconnect # type: ignore
from fastapi import WebSocketDisconnect # type: ignore
from starlette.websockets import WebSocketState

from .connect import (
Expand All @@ -14,6 +14,7 @@
launch_kernel,
connect_channel,
cfg_t,
AcceptedWebSocket,
) # type: ignore
from .message import (
receive_message,
Expand All @@ -24,7 +25,9 @@
from_binary,
get_channel_parts,
get_parent_header,
get_bin_msg,
get_zmq_parts,
get_bin_msg_from_parts,
get_msg_from_parts,
) # type: ignore


Expand All @@ -46,7 +49,7 @@ def __init__(
self.connection_file = connection_file
self.write_connection_file = write_connection_file
self.channel_tasks: List[asyncio.Task] = []
self.sessions: Dict[str, WebSocket] = {}
self.sessions: Dict[str, AcceptedWebSocket] = {}
# blocked messages and allowed messages are mutually exclusive
self.blocked_messages: List[str] = []
self.allowed_messages: Optional[
Expand Down Expand Up @@ -107,9 +110,9 @@ async def start(self) -> None:
self.iopub_channel = connect_channel("iopub", self.connection_cfg)
await self._wait_for_ready()
self.channel_tasks += [
asyncio.create_task(self.listen_shell()),
asyncio.create_task(self.listen_control()),
asyncio.create_task(self.listen_iopub()),
asyncio.create_task(self.listen("shell")),
asyncio.create_task(self.listen("control")),
asyncio.create_task(self.listen("iopub")),
]

async def stop(self) -> None:
Expand All @@ -133,55 +136,78 @@ async def restart(self) -> None:
self.setup_connection_file()
await self.start()

async def serve(self, websocket: WebSocket, session_id: str):
async def serve(self, websocket: AcceptedWebSocket, session_id: str):
self.sessions[session_id] = websocket
await self.listen_web(websocket)
del self.sessions[session_id]

async def listen_web(self, websocket: WebSocket):
async def listen_web(self, websocket: AcceptedWebSocket):
try:
while True:
msg = await websocket.receive_bytes()
# FIXME: add back message filtering
channel, parts = get_channel_parts(msg)
if channel == "shell":
send_raw_message(parts, self.shell_channel, self.key)
elif channel == "control":
send_raw_message(parts, self.control_channel, self.key)
if not websocket.accepted_subprotocol:
while True:
msg = await receive_json_or_bytes(websocket.websocket)
msg_type = msg["header"]["msg_type"]
if (msg_type in self.blocked_messages) or (
self.allowed_messages is not None
and msg_type not in self.allowed_messages
):
continue
channel = msg.pop("channel")
if channel == "shell":
send_message(msg, self.shell_channel, self.key)
elif channel == "control":
send_message(msg, self.control_channel, self.key)
elif websocket.accepted_subprotocol == "0.0.1":
while True:
msg = await websocket.websocket.receive_bytes()
# FIXME: add back message filtering
channel, parts = get_channel_parts(msg)
if channel == "shell":
send_raw_message(parts, self.shell_channel, self.key)
elif channel == "control":
send_raw_message(parts, self.control_channel, self.key)
except WebSocketDisconnect:
pass

async def listen_shell(self):
while True:
parts = await self.shell_channel.recv_multipart()
parent_header = get_parent_header(parts)
session = parent_header["session"]
if session in self.sessions:
websocket = self.sessions[session]
bin_msg = get_bin_msg("shell", parts)
await websocket.send_bytes(bin_msg)
async def listen(self, channel_name: str):
if channel_name == "shell":
channel = self.shell_channel
elif channel_name == "control":
channel = self.control_channel
elif channel_name == "iopub":
channel = self.iopub_channel

async def listen_control(self):
while True:
parts = await self.control_channel.recv_multipart()
parts = await get_zmq_parts(channel)
parent_header = get_parent_header(parts)
session = parent_header["session"]
if session in self.sessions:
websocket = self.sessions[session]
bin_msg = get_bin_msg("control", parts)
await websocket.send_bytes(bin_msg)

async def listen_iopub(self):
while True:
parts = await self.iopub_channel.recv_multipart()
bin_msg = get_bin_msg("iopub", parts)
for websocket in self.sessions.values():
try:
await websocket.send_bytes(bin_msg)
except Exception:
pass
# FIXME: add back last_activity update
# or replace it with control channel retrieving
if channel == self.iopub_channel:
# broadcast to all web clients
for websocket in self.sessions.values():
if not websocket.accepted_subprotocol:
# default, "legacy" protocol
msg = get_msg_from_parts(parts, parent_header=parent_header)
msg["channel"] = channel_name
await send_json_or_bytes(websocket.websocket, msg)
elif websocket.accepted_subprotocol == "0.0.1":
bin_msg = get_bin_msg_from_parts(channel_name, parts)
try:
await websocket.websocket.send_bytes(bin_msg)
except Exception:
pass
# FIXME: add back last_activity update
# or should we request it from the control channel?
else:
session = parent_header["session"]
if session in self.sessions:
websocket = self.sessions[session]
if not websocket.accepted_subprotocol:
# default, "legacy" protocol
msg = get_msg_from_parts(parts, parent_header=parent_header)
msg["channel"] = channel_name
await send_json_or_bytes(websocket.websocket, msg)
elif websocket.accepted_subprotocol == "0.0.1":
bin_msg = get_bin_msg_from_parts(channel_name, parts)
await websocket.websocket.send_bytes(bin_msg)

async def _wait_for_ready(self):
while True:
Expand Down
8 changes: 5 additions & 3 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fps_auth.config import get_auth_config # type: ignore
from fps_lab.config import get_lab_config # type: ignore

from .kernel_server.server import KernelServer, kernels # type: ignore
from .kernel_server.server import AcceptedWebSocket, KernelServer, kernels # type: ignore
from .models import Session

router = APIRouter()
Expand Down Expand Up @@ -202,10 +202,12 @@ async def kernel_channels(
if user:
accept_websocket = True
if accept_websocket:
await websocket.accept()
subprotocol = "0.0.1" if "0.0.1" in websocket["subprotocols"] else None
await websocket.accept(subprotocol=subprotocol)
accepted_websocket = AcceptedWebSocket(websocket, subprotocol)
if kernel_id in kernels:
kernel_server = kernels[kernel_id]["server"]
await kernel_server.serve(websocket, session_id)
await kernel_server.serve(accepted_websocket, session_id)
else:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)

Expand Down

0 comments on commit dabf937

Please sign in to comment.