diff --git a/plugins/kernels/fps_kernels/kernel_server/server.py b/plugins/kernels/fps_kernels/kernel_server/server.py index 0af1950e..c828ab8b 100644 --- a/plugins/kernels/fps_kernels/kernel_server/server.py +++ b/plugins/kernels/fps_kernels/kernel_server/server.py @@ -38,17 +38,9 @@ def __init__( ) -> None: self.capture_kernel_output = capture_kernel_output self.kernelspec_path = kernelspec_path - if write_connection_file: - self.connection_file_path, self.connection_cfg = _write_connection_file( - connection_file - ) - elif connection_file: - self.connection_file_path = connection_file - self.connection_cfg = read_connection_file(connection_file) - else: - assert connection_cfg is not None - self.connection_cfg = connection_cfg - self.key = cast(str, self.connection_cfg["key"]) + self.connection_cfg = connection_cfg + self.connection_file = connection_file + self.write_connection_file = write_connection_file self.channel_tasks: List[asyncio.Task] = [] self.sessions: Dict[str, WebSocket] = {} # blocked messages and allowed messages are mutually exclusive @@ -57,6 +49,20 @@ def __init__( List[str] ] = None # when None, all messages are allowed # when [], no message is allowed + self.setup_connection_file() + + def setup_connection_file(self): + if self.write_connection_file: + self.connection_file_path, self.connection_cfg = _write_connection_file( + self.connection_file + ) + elif self.connection_file: + self.connection_file_path = self.connection_file + self.connection_cfg = read_connection_file(self.connection_file) + else: + if self.connection_cfg is None: + raise RuntimeError("No connection_cfg") + self.key = cast(str, self.connection_cfg["key"]) def block_messages(self, message_types: Iterable[str] = []): # if using blocked messages, discard allowed messages @@ -118,26 +124,9 @@ async def stop(self) -> None: self.channel_tasks = [] async def restart(self) -> None: - self.last_activity = { - "date": datetime.utcnow().isoformat() + "Z", - "execution_state": "starting", - } - for task in self.channel_tasks: - task.cancel() - self.channel_tasks = [] - msg = create_message("shutdown_request", content={"restart": True}) - send_message(msg, self.control_channel, self.key) - while True: - msg2 = await receive_message(self.control_channel) - assert msg2 is not None - if msg2["msg_type"] == "shutdown_reply" and msg2["content"]["restart"]: - break - await self._wait_for_ready() - self.channel_tasks += [ - asyncio.create_task(self.listen_shell()), - asyncio.create_task(self.listen_control()), - asyncio.create_task(self.listen_iopub()), - ] + await self.stop() + self.setup_connection_file() + await self.start() async def serve(self, websocket: WebSocket, session_id: str): self.sessions[session_id] = websocket