Skip to content

Commit

Permalink
Manually restart kernel as stop and start
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 14, 2022
1 parent ff27130 commit 1bbdceb
Showing 1 changed file with 21 additions and 31 deletions.
52 changes: 21 additions & 31 deletions plugins/kernels/fps_kernels/kernel_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,9 @@ def __init__(
) -> None:
self.capture_kernel_output = capture_kernel_output
self.kernelspec_path = kernelspec_path
if write_connection_file:
self.connection_file_path, self.connection_cfg = _write_connection_file(
connection_file
)
elif connection_file:
self.connection_file_path = connection_file
self.connection_cfg = read_connection_file(connection_file)
else:
assert connection_cfg is not None
self.connection_cfg = connection_cfg
self.key = cast(str, self.connection_cfg["key"])
self.connection_cfg = connection_cfg
self.connection_file = connection_file
self.write_connection_file = write_connection_file
self.channel_tasks: List[asyncio.Task] = []
self.sessions: Dict[str, WebSocket] = {}
# blocked messages and allowed messages are mutually exclusive
Expand All @@ -57,6 +49,20 @@ def __init__(
List[str]
] = None # when None, all messages are allowed
# when [], no message is allowed
self.setup_connection_file()

def setup_connection_file(self):
if self.write_connection_file:
self.connection_file_path, self.connection_cfg = _write_connection_file(
self.connection_file
)
elif self.connection_file:
self.connection_file_path = self.connection_file
self.connection_cfg = read_connection_file(self.connection_file)
else:
if self.connection_cfg is None:
raise RuntimeError("No connection_cfg")
self.key = cast(str, self.connection_cfg["key"])

def block_messages(self, message_types: Iterable[str] = []):
# if using blocked messages, discard allowed messages
Expand Down Expand Up @@ -91,6 +97,7 @@ async def start(self) -> None:
self.kernel_process = await launch_kernel(
self.kernelspec_path, self.connection_file_path, self.capture_kernel_output
)
assert self.connection_cfg is not None
self.shell_channel = connect_channel("shell", self.connection_cfg)
self.control_channel = connect_channel("control", self.connection_cfg)
self.iopub_channel = connect_channel("iopub", self.connection_cfg)
Expand Down Expand Up @@ -118,26 +125,9 @@ async def stop(self) -> None:
self.channel_tasks = []

async def restart(self) -> None:
self.last_activity = {
"date": datetime.utcnow().isoformat() + "Z",
"execution_state": "starting",
}
for task in self.channel_tasks:
task.cancel()
self.channel_tasks = []
msg = create_message("shutdown_request", content={"restart": True})
send_message(msg, self.control_channel, self.key)
while True:
msg2 = await receive_message(self.control_channel)
assert msg2 is not None
if msg2["msg_type"] == "shutdown_reply" and msg2["content"]["restart"]:
break
await self._wait_for_ready()
self.channel_tasks += [
asyncio.create_task(self.listen_shell()),
asyncio.create_task(self.listen_control()),
asyncio.create_task(self.listen_iopub()),
]
await self.stop()
self.setup_connection_file()
await self.start()

async def serve(self, websocket: WebSocket, session_id: str):
self.sessions[session_id] = websocket
Expand Down

0 comments on commit 1bbdceb

Please sign in to comment.