From 1bbdceb7a0eccf66dcd3e72f792168ed1ed0f8ad Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 14 Jan 2022 09:05:25 +0100 Subject: [PATCH] Manually restart kernel as stop and start --- .../fps_kernels/kernel_server/server.py | 52 ++++++++----------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/plugins/kernels/fps_kernels/kernel_server/server.py b/plugins/kernels/fps_kernels/kernel_server/server.py index 0af1950e..71335b70 100644 --- a/plugins/kernels/fps_kernels/kernel_server/server.py +++ b/plugins/kernels/fps_kernels/kernel_server/server.py @@ -38,17 +38,9 @@ def __init__( ) -> None: self.capture_kernel_output = capture_kernel_output self.kernelspec_path = kernelspec_path - if write_connection_file: - self.connection_file_path, self.connection_cfg = _write_connection_file( - connection_file - ) - elif connection_file: - self.connection_file_path = connection_file - self.connection_cfg = read_connection_file(connection_file) - else: - assert connection_cfg is not None - self.connection_cfg = connection_cfg - self.key = cast(str, self.connection_cfg["key"]) + self.connection_cfg = connection_cfg + self.connection_file = connection_file + self.write_connection_file = write_connection_file self.channel_tasks: List[asyncio.Task] = [] self.sessions: Dict[str, WebSocket] = {} # blocked messages and allowed messages are mutually exclusive @@ -57,6 +49,20 @@ def __init__( List[str] ] = None # when None, all messages are allowed # when [], no message is allowed + self.setup_connection_file() + + def setup_connection_file(self): + if self.write_connection_file: + self.connection_file_path, self.connection_cfg = _write_connection_file( + self.connection_file + ) + elif self.connection_file: + self.connection_file_path = self.connection_file + self.connection_cfg = read_connection_file(self.connection_file) + else: + if self.connection_cfg is None: + raise RuntimeError("No connection_cfg") + self.key = cast(str, self.connection_cfg["key"]) def block_messages(self, message_types: Iterable[str] = []): # if using blocked messages, discard allowed messages @@ -91,6 +97,7 @@ async def start(self) -> None: self.kernel_process = await launch_kernel( self.kernelspec_path, self.connection_file_path, self.capture_kernel_output ) + assert self.connection_cfg is not None self.shell_channel = connect_channel("shell", self.connection_cfg) self.control_channel = connect_channel("control", self.connection_cfg) self.iopub_channel = connect_channel("iopub", self.connection_cfg) @@ -118,26 +125,9 @@ async def stop(self) -> None: self.channel_tasks = [] async def restart(self) -> None: - self.last_activity = { - "date": datetime.utcnow().isoformat() + "Z", - "execution_state": "starting", - } - for task in self.channel_tasks: - task.cancel() - self.channel_tasks = [] - msg = create_message("shutdown_request", content={"restart": True}) - send_message(msg, self.control_channel, self.key) - while True: - msg2 = await receive_message(self.control_channel) - assert msg2 is not None - if msg2["msg_type"] == "shutdown_reply" and msg2["content"]["restart"]: - break - await self._wait_for_ready() - self.channel_tasks += [ - asyncio.create_task(self.listen_shell()), - asyncio.create_task(self.listen_control()), - asyncio.create_task(self.listen_iopub()), - ] + await self.stop() + self.setup_connection_file() + await self.start() async def serve(self, websocket: WebSocket, session_id: str): self.sessions[session_id] = websocket