Skip to content

Commit

Permalink
Filter out messages (#83)
Browse files Browse the repository at this point in the history
* Filter out messages

* Add test

* Fix tests
  • Loading branch information
davidbrochart committed Oct 1, 2021
1 parent 77675b3 commit 9f557ef
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
pip install ./plugins/kernels
pip install ./plugins/lab
pip install ./plugins/jupyterlab
pip install flake8 black mypy pytest requests
pip install flake8 black mypy pytest pytest-asyncio requests ipykernel
- name: Check style
run: |
Expand Down
46 changes: 41 additions & 5 deletions plugins/kernels/fps_kernels/kernel_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import signal
from datetime import datetime
from typing import Optional, List, Dict, cast
from typing import Iterable, Optional, List, Dict, cast

from fastapi import WebSocket, WebSocketDisconnect # type: ignore

Expand Down Expand Up @@ -43,6 +43,29 @@ def __init__(
self.key = cast(str, self.connection_cfg["key"])
self.channel_tasks: List[asyncio.Task] = []
self.sessions: Dict[str, WebSocket] = {}
# blocked messages and allowed messages are mutually exclusive
self.blocked_messages: List[str] = []
self.allowed_messages: Optional[
List[str]
] = None # when None, all messages are allowed
# when [], no message is allowed

def block_messages(self, message_types: Iterable[str] = []):
# if using blocked messages, discard allowed messages
self.allowed_messages = None
if isinstance(message_types, str):
message_types = [message_types]
self.blocked_messages = list(message_types)

def allow_messages(self, message_types: Optional[Iterable[str]] = None):
# if using allowed messages, discard blocked messages
self.blocked_messages = []
if message_types is None:
self.allowed_messages = None
return
if isinstance(message_types, str):
message_types = [message_types]
self.allowed_messages = list(message_types)

@property
def connections(self) -> int:
Expand Down Expand Up @@ -71,10 +94,17 @@ async def start(self) -> None:
]

async def stop(self) -> None:
self.kernel_process.send_signal(signal.SIGINT)
self.kernel_process.kill()
await self.kernel_process.wait()
os.remove(self.connection_file_path)
# FIXME: stop kernel in a better way
try:
self.kernel_process.send_signal(signal.SIGINT)
self.kernel_process.kill()
await self.kernel_process.wait()
except Exception:
pass
try:
os.remove(self.connection_file_path)
except Exception:
pass
for task in self.channel_tasks:
task.cancel()
self.channel_tasks = []
Expand Down Expand Up @@ -110,6 +140,12 @@ async def listen_web(self, websocket: WebSocket):
try:
while True:
msg = await websocket.receive_json()
msg_type = msg["header"]["msg_type"]
if (msg_type in self.blocked_messages) or (
self.allowed_messages is not None
and msg_type not in self.allowed_messages
):
continue
channel = msg["channel"]
msg = {
"header": msg["header"],
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ test =
black
mypy
pytest
pytest-asyncio
requests
ipykernel

[options.entry_points]
console_scripts =
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,21 @@
import socket
import subprocess
import time
import asyncio

import pytest

pytest_plugins = ("fps.testing.fixtures",)


@pytest.fixture(scope="session")
def event_loop():
"""Change event_loop fixture to module level."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()


@pytest.fixture()
def authenticated_user(client):
username = uuid4().hex
Expand Down
10 changes: 9 additions & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@
)


def test_kernel_channels(client, authenticated_user):
def test_kernel_channels_unauthenticated(client):
with pytest.raises(KeyError):
with client.websocket_connect(
"/api/kernels/kernel_id_0/channels?session_id=session_id_0",
):
pass


def test_kernel_channels_authenticated(client, authenticated_user):
with client.websocket_connect(
"/api/kernels/kernel_id_0/channels?session_id=session_id_0",
cookies=client.cookies,
Expand Down
75 changes: 75 additions & 0 deletions tests/test_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os
from time import sleep

import pytest
from fps_kernels.kernel_server.server import kernels, KernelServer


@pytest.mark.asyncio
@pytest.mark.parametrize("auth_mode", ("noauth",))
async def test_kernel_messages(auth_mode, client, capfd):
kernel_id = "kernel_id_0"
kernel_name = "python3"
kernelspec_path = (
os.environ["CONDA_PREFIX"] + f"/share/jupyter/kernels/{kernel_name}/kernel.json"
)
kernel_server = KernelServer(
kernelspec_path=kernelspec_path, capture_kernel_output=False
)
await kernel_server.start()
kernels[kernel_id] = {"server": kernel_server}
msg_id = "0"
msg = {
"channel": "shell",
"parent_header": None,
"content": None,
"metadata": None,
"header": {
"msg_type": "msg_type_0",
"msg_id": msg_id,
},
}

# block msg_type_0
msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1)
kernel_server.block_messages("msg_type_0")
with client.websocket_connect(
f"/api/kernels/{kernel_id}/channels?session_id=session_id_0",
) as websocket:
websocket.send_json(msg)
sleep(0.1)
out, err = capfd.readouterr()
assert not err

# allow only msg_type_0
msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1)
kernel_server.allow_messages("msg_type_0")
with client.websocket_connect(
f"/api/kernels/{kernel_id}/channels?session_id=session_id_0",
) as websocket:
websocket.send_json(msg)
sleep(0.1)
out, err = capfd.readouterr()
assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") == 1

# block all messages
msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1)
kernel_server.allow_messages([])
with client.websocket_connect(
f"/api/kernels/{kernel_id}/channels?session_id=session_id_0",
) as websocket:
websocket.send_json(msg)
sleep(0.1)
out, err = capfd.readouterr()
assert not err

# allow all messages
msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1)
kernel_server.allow_messages()
with client.websocket_connect(
f"/api/kernels/{kernel_id}/channels?session_id=session_id_0",
) as websocket:
websocket.send_json(msg)
sleep(0.1)
out, err = capfd.readouterr()
assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") == 1

0 comments on commit 9f557ef

Please sign in to comment.