diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9a29ee14..b0e42fb6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -43,7 +43,7 @@ jobs: pip install ./plugins/kernels pip install ./plugins/lab pip install ./plugins/jupyterlab - pip install flake8 black mypy pytest requests + pip install flake8 black mypy pytest pytest-asyncio requests ipykernel - name: Check style run: | diff --git a/plugins/kernels/fps_kernels/kernel_server/server.py b/plugins/kernels/fps_kernels/kernel_server/server.py index 949f957d..ee1d6fae 100644 --- a/plugins/kernels/fps_kernels/kernel_server/server.py +++ b/plugins/kernels/fps_kernels/kernel_server/server.py @@ -2,7 +2,7 @@ import asyncio import signal from datetime import datetime -from typing import Optional, List, Dict, cast +from typing import Iterable, Optional, List, Dict, cast from fastapi import WebSocket, WebSocketDisconnect # type: ignore @@ -43,6 +43,29 @@ def __init__( self.key = cast(str, self.connection_cfg["key"]) self.channel_tasks: List[asyncio.Task] = [] self.sessions: Dict[str, WebSocket] = {} + # blocked messages and allowed messages are mutually exclusive + self.blocked_messages: List[str] = [] + self.allowed_messages: Optional[ + List[str] + ] = None # when None, all messages are allowed + # when [], no message is allowed + + def block_messages(self, message_types: Iterable[str] = []): + # if using blocked messages, discard allowed messages + self.allowed_messages = None + if isinstance(message_types, str): + message_types = [message_types] + self.blocked_messages = list(message_types) + + def allow_messages(self, message_types: Optional[Iterable[str]] = None): + # if using allowed messages, discard blocked messages + self.blocked_messages = [] + if message_types is None: + self.allowed_messages = None + return + if isinstance(message_types, str): + message_types = [message_types] + self.allowed_messages = list(message_types) @property def connections(self) -> int: @@ -71,10 +94,17 @@ async def start(self) -> None: ] async def stop(self) -> None: - self.kernel_process.send_signal(signal.SIGINT) - self.kernel_process.kill() - await self.kernel_process.wait() - os.remove(self.connection_file_path) + # FIXME: stop kernel in a better way + try: + self.kernel_process.send_signal(signal.SIGINT) + self.kernel_process.kill() + await self.kernel_process.wait() + except Exception: + pass + try: + os.remove(self.connection_file_path) + except Exception: + pass for task in self.channel_tasks: task.cancel() self.channel_tasks = [] @@ -110,6 +140,12 @@ async def listen_web(self, websocket: WebSocket): try: while True: msg = await websocket.receive_json() + msg_type = msg["header"]["msg_type"] + if (msg_type in self.blocked_messages) or ( + self.allowed_messages is not None + and msg_type not in self.allowed_messages + ): + continue channel = msg["channel"] msg = { "header": msg["header"], diff --git a/setup.cfg b/setup.cfg index dbd9ca39..08d549ab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,7 +40,9 @@ test = black mypy pytest + pytest-asyncio requests + ipykernel [options.entry_points] console_scripts = diff --git a/tests/conftest.py b/tests/conftest.py index 373fcd1c..1c6f9162 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,21 @@ import socket import subprocess import time +import asyncio import pytest pytest_plugins = ("fps.testing.fixtures",) +@pytest.fixture(scope="session") +def event_loop(): + """Change event_loop fixture to module level.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + @pytest.fixture() def authenticated_user(client): username = uuid4().hex diff --git a/tests/test_auth.py b/tests/test_auth.py index ac70cf64..63c911cc 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -8,7 +8,15 @@ ) -def test_kernel_channels(client, authenticated_user): +def test_kernel_channels_unauthenticated(client): + with pytest.raises(KeyError): + with client.websocket_connect( + "/api/kernels/kernel_id_0/channels?session_id=session_id_0", + ): + pass + + +def test_kernel_channels_authenticated(client, authenticated_user): with client.websocket_connect( "/api/kernels/kernel_id_0/channels?session_id=session_id_0", cookies=client.cookies, diff --git a/tests/test_kernels.py b/tests/test_kernels.py new file mode 100644 index 00000000..ab6f998e --- /dev/null +++ b/tests/test_kernels.py @@ -0,0 +1,75 @@ +import os +from time import sleep + +import pytest +from fps_kernels.kernel_server.server import kernels, KernelServer + + +@pytest.mark.asyncio +@pytest.mark.parametrize("auth_mode", ("noauth",)) +async def test_kernel_messages(auth_mode, client, capfd): + kernel_id = "kernel_id_0" + kernel_name = "python3" + kernelspec_path = ( + os.environ["CONDA_PREFIX"] + f"/share/jupyter/kernels/{kernel_name}/kernel.json" + ) + kernel_server = KernelServer( + kernelspec_path=kernelspec_path, capture_kernel_output=False + ) + await kernel_server.start() + kernels[kernel_id] = {"server": kernel_server} + msg_id = "0" + msg = { + "channel": "shell", + "parent_header": None, + "content": None, + "metadata": None, + "header": { + "msg_type": "msg_type_0", + "msg_id": msg_id, + }, + } + + # block msg_type_0 + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.block_messages("msg_type_0") + with client.websocket_connect( + f"/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + websocket.send_json(msg) + sleep(0.1) + out, err = capfd.readouterr() + assert not err + + # allow only msg_type_0 + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages("msg_type_0") + with client.websocket_connect( + f"/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + websocket.send_json(msg) + sleep(0.1) + out, err = capfd.readouterr() + assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") == 1 + + # block all messages + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages([]) + with client.websocket_connect( + f"/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + websocket.send_json(msg) + sleep(0.1) + out, err = capfd.readouterr() + assert not err + + # allow all messages + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages() + with client.websocket_connect( + f"/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + websocket.send_json(msg) + sleep(0.1) + out, err = capfd.readouterr() + assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") == 1