Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions plugins/remote_kernels/txl_remote_kernels/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import httpx
from anyio import Lock, sleep
from anyioutils import create_task
from httpx import USE_CLIENT_DEFAULT, Timeout
from httpx_ws import aconnect_ws
from txl_kernel.driver import KernelMixin
from txl_kernel.message import date_to_str
Expand All @@ -30,6 +31,7 @@ def __init__(
url: str,
kernel_name: str | None = "",
comm_handlers=[],
timeout: Timeout = USE_CLIENT_DEFAULT,
) -> None:
super().__init__(task_group)
self.task_group = task_group
Expand Down Expand Up @@ -78,6 +80,7 @@ async def start(self):
params={"session_id": self.session_id},
cookies=self.cookies,
subprotocols=["v1.kernel.websocket.jupyter.org"],
timeout=self.timeout,
) as self.websocket:
recv_task = create_task(self._recv(), self.task_group)
try:
Expand Down
18 changes: 15 additions & 3 deletions plugins/remote_kernels/txl_remote_kernels/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import httpx
from anyio import create_task_group, sleep
from fps import Module
from httpx import USE_CLIENT_DEFAULT, Timeout
from pycrdt import Map

from txl.base import Kernels, Kernelspecs
Expand All @@ -20,9 +21,11 @@ def __init__(
self,
url: str,
kernel_name: str | None,
*,
timeout: Timeout = USE_CLIENT_DEFAULT,
):
self.kernel = KernelDriver(
self.task_group, url, kernel_name, comm_handlers=self.comm_handlers
self.task_group, url, kernel_name, comm_handlers=self.comm_handlers, timeout=timeout
)

async def execute(self, ycell: Map):
Expand Down Expand Up @@ -56,19 +59,28 @@ async def get(self) -> dict[str, Any]:


class RemoteKernelsModule(Module):
def __init__(self, name: str, url: str = "http://127.0.0.1:8000"):
def __init__(
self,
name: str,
url: str = "http://127.0.0.1:8000",
*,
timeout: Timeout = USE_CLIENT_DEFAULT,
):
super().__init__(name)
self.url = url
self.timeout = timeout

async def start(self) -> None:
url = self.url

async with create_task_group() as self.tg:

class _RemoteKernels(RemoteKernels):
task_group = self.tg
timeout = self.timeout

def __init__(self, *args, **kwargs):
super().__init__(url, *args, **kwargs)
super().__init__(url, *args, timeout=self.timeout, **kwargs)

self.put(_RemoteKernels, Kernels)
self.done()
Expand Down
Loading