Skip to content

feat:support nbmodel #1535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions jupyter_server/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@

import asyncio
import datetime
import inspect
import json
import os
import time
import typing as t
from queue import Empty, Queue
from threading import Thread
from time import monotonic
from turtle import st
from types import CoroutineType, coroutine
from typing import TYPE_CHECKING, Any, Optional, cast

import websocket
Expand Down Expand Up @@ -642,6 +647,8 @@ async def get_msg(self, *args: Any, **kwargs: Any) -> dict[str, Any]:

def send(self, msg: dict[str, Any]) -> None:
"""Send a message to the queue."""
if "channel" not in msg:
msg["channel"] = self.channel_name
message = json.dumps(msg, default=ChannelQueue.serialize_datetime).replace("</", "<\\/")
self.log.debug(
"Sending message on channel: %s, msg_id: %s, msg_type: %s",
Expand Down Expand Up @@ -683,6 +690,9 @@ def is_alive(self) -> bool:
"""Whether the queue is alive."""
return self.channel_socket is not None

async def msg_ready(self) -> bool:
return not self.empty()


class HBChannelQueue(ChannelQueue):
"""A queue for the heartbeat channel."""
Expand Down Expand Up @@ -877,5 +887,187 @@ def _route_responses(self):

self.log.debug("Response router thread exiting...")

async def _maybe_awaitable(self, func_result):
"""Helper to handle potentially awaitable results"""
if inspect.isawaitable(func_result):
await func_result

async def _handle_iopub_stdin_messages(
self,
msg_id: str,
output_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]],
stdin_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]],
timeout: t.Optional[float],
allow_stdin: bool,
start_time: float,
) -> None:
"""Handle IOPub messages until idle state"""
while True:
# Calculate remaining timeout
if timeout is not None:
elapsed = time.monotonic() - start_time
remaining = max(0, timeout - elapsed)
if remaining <= 0:
raise TimeoutError("Timeout in IOPub handling")
else:
remaining = None
if stdin_hook is not None and allow_stdin:
await self._handle_stdin_messages(stdin_hook, allow_stdin)
try:
msg = await self.iopub_channel.get_msg(timeout=remaining)
except Exception as e:
self.log.warning(f"err ({e})")

if msg["parent_header"].get("msg_id") != msg_id:
continue

if output_hook is not None:
await self._maybe_awaitable(output_hook(msg))

if (
msg["header"]["msg_type"] == "status"
and msg["content"].get("execution_state") == "idle"
):
break

async def _handle_stdin_messages(
self,
stdin_hook: t.Callable[[dict[str, t.Any]], t.Any],
allow_stdin: bool,
) -> None:
"""Handle stdin messages until iopub is idle"""
if not allow_stdin:
return
try:
msg = await self.stdin_channel.get_msg(timeout=0.01)
self.log.info(f"stdin msg: {msg},{type(msg)}")
await self._maybe_awaitable(stdin_hook(msg))
except (Empty, TimeoutError):
pass
except Exception:
self.log.warning("Error handling stdin message", exc_info=True)

async def _wait_for_execution_reply(
self, msg_id: str, timeout: t.Optional[float], start_time: float
) -> dict[str, t.Any]:
"""Wait for execution reply from shell or control channel"""
# Calculate remaining timeout
if timeout is not None:
elapsed = time.monotonic() - start_time
remaining_timeout = max(0, timeout - elapsed)
if remaining_timeout <= 0:
raise TimeoutError("Timeout waiting for reply")
else:
remaining_timeout = None

deadline = time.monotonic() + remaining_timeout if remaining_timeout else None

while True:
if deadline:
remaining = max(0, deadline - time.monotonic())
if remaining <= 0:
raise TimeoutError("Timeout waiting for reply")
else:
remaining = None

# Listen to both shell and control channels
reply_task = asyncio.create_task(self.shell_channel.get_msg(timeout=remaining))
control_task = asyncio.create_task(self.control_channel.get_msg(timeout=remaining))

try:
done, pending = await asyncio.wait(
[reply_task, control_task],
timeout=remaining,
return_when=asyncio.FIRST_COMPLETED,
)

# Cancel pending tasks
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

if not done:
raise TimeoutError("Timeout waiting for reply")

for task in done:
try:
msg: dict[str, t.Any] = task.result()
if msg["parent_header"].get("msg_id") == msg_id:
return msg
except Exception:
continue

except asyncio.TimeoutError as err:
reply_task.cancel()
control_task.cancel()
raise TimeoutError("Timeout waiting for reply") from err

async def execute_interactive(
self,
code: str,
silent: bool = False,
store_history: bool = True,
user_expressions: t.Optional[dict[str, t.Any]] = None,
allow_stdin: t.Optional[bool] = None,
stop_on_error: bool = True,
timeout: t.Optional[float] = None,
output_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]] = None,
stdin_hook: t.Optional[t.Callable[[dict[str, t.Any]], t.Any]] = None,
) -> dict[str, t.Any]: # type: ignore[override] # Reason: base class sets `execute_interactive` via assignment, so mypy cannot infer override compatibility
"""Execute code in the kernel interactively via gateway"""

# Channel alive checks
if not self.iopub_channel.is_alive():
raise RuntimeError("IOPub channel must be running to receive output")

# Prepare defaults
if allow_stdin is None:
allow_stdin = self.allow_stdin

if output_hook is None:
output_hook = self._output_hook_default
if stdin_hook is None:
stdin_hook = self._stdin_hook_default

# Execute the code
msg_id = self.execute(
code=code,
silent=silent,
store_history=store_history,
user_expressions=user_expressions,
allow_stdin=allow_stdin,
stop_on_error=stop_on_error,
)

# Setup coordination
start_time = time.monotonic()

try:
# Handle IOPub messages until idle
iopub_task = asyncio.create_task(
self._handle_iopub_stdin_messages(
msg_id, output_hook, stdin_hook, timeout, allow_stdin, start_time
),
name="handle_iopub_stdin_messages",
)
await iopub_task
# Get the execution reply
reply = await self._wait_for_execution_reply(msg_id, timeout, start_time)
return reply

except asyncio.CancelledError:
raise
except TimeoutError:
raise
except Exception as e:
self.log.error(
f"Error during interactive execution: {e}, msg_id: {msg_id}",
exc_info=True,
)
raise RuntimeError(f"Error in interactive execution: {e}") from e


KernelClientABC.register(GatewayKernelClient)
98 changes: 96 additions & 2 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,16 @@
from traitlets.config import Config

from jupyter_server.gateway.connections import GatewayWebSocketConnection
from jupyter_server.gateway.gateway_client import GatewayTokenRenewerBase, NoOpTokenRenewer
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient, GatewayKernelManager
from jupyter_server.gateway.gateway_client import (
GatewayTokenRenewerBase,
NoOpTokenRenewer,
)
from jupyter_server.gateway.managers import (
ChannelQueue,
GatewayClient,
GatewayKernelClient,
GatewayKernelManager,
)
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler

from .utils import expected_http_error
Expand Down Expand Up @@ -902,3 +910,89 @@ async def delete_kernel(jp_fetch, kernel_id):
r = await jp_fetch("api", "kernels", kernel_id, method="DELETE")
assert r.code == 204
assert r.reason == "No Content"


@pytest.fixture
def mock_channel_queue():
queue = ChannelQueue("shell", MagicMock(), MagicMock())
return queue


@pytest.fixture
def gateway_kernel_client(init_gateway, monkeypatch):
client = GatewayKernelClient("fake-kernel-id")
client._channel_queues = {
"shell": ChannelQueue("shell", MagicMock(), MagicMock()),
"iopub": ChannelQueue("iopub", MagicMock(), MagicMock()),
"stdin": ChannelQueue("stdin", MagicMock(), MagicMock()),
"hb": ChannelQueue("hb", MagicMock(), MagicMock()),
"control": ChannelQueue("control", MagicMock(), MagicMock()),
}
client._shell_channel = client._channel_queues["shell"]
client._iopub_channel = client._channel_queues["iopub"]
client._stdin_channel = client._channel_queues["stdin"]
client._hb_channel = client._channel_queues["hb"]
client._control_channel = client._channel_queues["control"]
return client


def fake_create_connection(*args, **kwargs):
return MagicMock()


async def test_gateway_kernel_client_start_and_stop_channels(gateway_kernel_client, monkeypatch):
monkeypatch.setattr("websocket.create_connection", fake_create_connection)
monkeypatch.setattr(gateway_kernel_client, "channel_socket", MagicMock())
monkeypatch.setattr(gateway_kernel_client, "response_router", MagicMock())
await gateway_kernel_client.start_channels()
gateway_kernel_client.stop_channels()
assert gateway_kernel_client._channels_stopped


# @pytest.mark.asyncio
async def test_gateway_kernel_client_execute_interactive(gateway_kernel_client, monkeypatch):
gateway_kernel_client.execute = MagicMock(return_value="msg-123")

async def fake_shell_get_msg(timeout=None):
return {"parent_header": {"msg_id": "msg-123"}, "msg_type": "execute_reply"}

gateway_kernel_client.shell_channel.get_msg = fake_shell_get_msg

async def fake_iopub_get_msg(timeout=None):
await asyncio.sleep(0.01)
return {
"parent_header": {"msg_id": "msg-123"},
"msg_type": "status",
"header": {"msg_type": "status"},
"content": {"execution_state": "idle"},
}

gateway_kernel_client.iopub_channel.get_msg = fake_iopub_get_msg

async def fake_stdin_get_msg(timeout=None):
await asyncio.sleep(0.01)
return {"parent_header": {"msg_id": "msg-123"}, "msg_type": "input_request"}

gateway_kernel_client.stdin_channel.get_msg = fake_stdin_get_msg
output_msgs = []

async def output_hook(msg):
output_msgs.append(msg)

stdin_msgs = []

async def stdin_hook(msg):
stdin_msgs.append(msg)

reply = await gateway_kernel_client.execute_interactive(
"print(1)", output_hook=output_hook, stdin_hook=stdin_hook
)
assert reply["msg_type"] == "execute_reply"


async def test_gateway_channel_queue_get_msg_with_response_router_finished(
mock_channel_queue,
):
mock_channel_queue.response_router_finished = True
with pytest.raises(RuntimeError):
await mock_channel_queue.get_msg(timeout=0.1)
Loading