Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow KernelServer to connect to existing kernel #80

Merged
merged 1 commit into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions plugins/kernels/fps_kernels/kernel_server/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def write_connection_file(
return fname, cfg


def read_connection_file(fname: str) -> cfg_t:
with open(fname, "rt") as f:
cfg: cfg_t = json.load(f)

return cfg


async def launch_kernel(
kernelspec_path: str, connection_file_path: str, capture_output: bool
) -> asyncio.subprocess.Process:
Expand Down
34 changes: 26 additions & 8 deletions plugins/kernels/fps_kernels/kernel_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,44 @@
import asyncio
import signal
from datetime import datetime
from typing import List, Dict, cast
from typing import Optional, List, Dict, cast

from fastapi import WebSocket, WebSocketDisconnect # type: ignore

from .connect import write_connection_file, launch_kernel, connect_channel # type: ignore
from .connect import (
write_connection_file as _write_connection_file,
read_connection_file,
launch_kernel,
connect_channel,
cfg_t,
) # type: ignore
from .message import receive_message, send_message, create_message # type: ignore


kernels: dict = {}


class KernelServer:
def __init__(
self,
kernelspec_path: str = "",
connection_cfg: Optional[cfg_t] = None,
connection_file: str = "",
write_connection_file: bool = True,
capture_kernel_output: bool = True,
) -> None:
self.capture_kernel_output = capture_kernel_output
self.kernelspec_path = kernelspec_path
if not self.kernelspec_path:
raise RuntimeError(
"Could not find a kernel, maybe you forgot to install one?"
if write_connection_file:
self.connection_file_path, self.connection_cfg = _write_connection_file(
connection_file
)
self.connection_file_path, self.connection_cfg = write_connection_file(
connection_file
)
elif connection_file:
self.connection_file_path = connection_file
self.connection_cfg = read_connection_file(connection_file)
else:
assert connection_cfg is not None
self.connection_cfg = connection_cfg
self.key = cast(str, self.connection_cfg["key"])
self.channel_tasks: List[asyncio.Task] = []
self.sessions: Dict[str, WebSocket] = {}
Expand All @@ -35,6 +49,10 @@ def connections(self) -> int:
return len(self.sessions)

async def start(self) -> None:
if not self.kernelspec_path:
raise RuntimeError(
"Could not find a kernel, maybe you forgot to install one?"
)
self.last_activity = {
"date": datetime.utcnow().isoformat() + "Z",
"execution_state": "starting",
Expand Down
20 changes: 18 additions & 2 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
from fps_auth.db import get_user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

from .kernel_server.server import KernelServer # type: ignore
from .kernel_server.server import KernelServer, kernels # type: ignore
from .models import Session

router = APIRouter()

kernelspecs: dict = {}
sessions: dict = {}
kernels: dict = {}
prefix_dir: pathlib.Path = pathlib.Path(sys.prefix)


Expand Down Expand Up @@ -165,6 +164,23 @@ async def restart_kernel(
return result


@router.get("/api/kernels/{kernel_id}")
async def get_kernel(
kernel_id,
user: User = Depends(current_user()),
):
if kernel_id in kernels:
kernel = kernels[kernel_id]
result = {
"id": kernel_id,
"name": kernel["name"],
"connections": kernel["server"].connections,
"last_activity": kernel["server"].last_activity["date"],
"execution_state": kernel["server"].last_activity["execution_state"],
}
return result


@router.websocket("/api/kernels/{kernel_id}/channels")
async def kernel_channels(
websocket: WebSocket,
Expand Down