Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ To add a kernel spec for a `trio` backend.

```bash
pip install trio
async-kernel -a async-trio
async-kernel -a async-trio --interface.backend=trio
```

For further detail about kernel customisation see [command line usage](https://fleming79.github.io/async-kernel/latest/commands/#command-line).
For further detail about kernel spec customisation see [command line usage](https://fleming79.github.io/async-kernel/latest/commands/#command-line).

## Message handling

Expand Down
4 changes: 1 addition & 3 deletions src/async_kernel/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import async_kernel
from async_kernel.kernelspec import get_kernel_dir, import_start_interface, remove_kernel_spec, write_kernel_spec
from async_kernel.typing import KernelName

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -68,8 +67,7 @@ def command_line() -> None:
"-a",
"--add",
dest="add",
help=f"Add a kernel spec. Default kernel names are: {list(map(str, KernelName))}.\n"
+ "To specify a 'trio' backend, include 'trio' in the name. Other options are also permitted. See: `write_kernel_spec` for detail.",
help="Write a kernel spec with the corresponding name. This will overwrite existing kernel specs of the same name.",
)
kernels = [] if not kernel_dir.exists() else [item.name for item in kernel_dir.iterdir() if item.is_dir()]
parser.add_argument(
Expand Down
3 changes: 0 additions & 3 deletions src/async_kernel/interface/zmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,6 @@ async def run_kernel() -> None:
async with self.kernel:
await self.wait_exit

if not self.trait_has_value("backend") and "trio" in self.kernel.kernel_name.lower():
self.backend = Backend.trio

anyio.run(run_kernel, backend=self.backend, backend_options=self.backend_options)

@override
Expand Down
22 changes: 3 additions & 19 deletions src/async_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,7 @@
from async_kernel.common import Fixed
from async_kernel.debugger import Debugger
from async_kernel.interface.base import BaseKernelInterface
from async_kernel.typing import (
Channel,
Content,
ExecuteContent,
HandlerType,
Job,
KernelName,
Message,
MsgType,
NoValue,
RunMode,
)
from async_kernel.typing import Channel, Content, ExecuteContent, HandlerType, Job, Message, MsgType, NoValue, RunMode

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable, Iterable
Expand Down Expand Up @@ -197,13 +186,8 @@ def _default_log(self) -> LoggerAdapter[Logger]:
return logging.LoggerAdapter(logging.getLogger(self.__class__.__name__))

@traitlets.default("kernel_name")
def _default_kernel_name(self) -> Literal[KernelName.trio, KernelName.asyncio]:
try:
if current_async_library() == "trio":
return KernelName.trio
except Exception:
pass
return KernelName.asyncio
def _default_kernel_name(self):
return "async-trio" if current_async_library(failsafe=True) == "trio" else "async"

@traitlets.default("interface")
def default_interface(self):
Expand Down
6 changes: 0 additions & 6 deletions src/async_kernel/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
"FixedCreated",
"HandlerType",
"Job",
"KernelName",
"Message",
"MsgHeader",
"MsgType",
Expand All @@ -50,11 +49,6 @@ class Backend(enum.StrEnum):
trio = "trio"


class KernelName(enum.StrEnum):
asyncio = "async"
trio = "async-trio"


class Channel(enum.StrEnum):
"An enum of channels[Ref](https://jupyter-client.readthedocs.io/en/stable/messaging.html#introduction)."

Expand Down
11 changes: 6 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from async_kernel.interface.zmq import ZMQKernelInterface
from async_kernel.kernel import Kernel
from async_kernel.kernelspec import make_argv
from async_kernel.typing import Backend, Channel, ExecuteContent, Job, KernelName, Message, MsgHeader, MsgType
from async_kernel.typing import Backend, Channel, ExecuteContent, Job, Message, MsgHeader, MsgType
from tests import utils

if TYPE_CHECKING:
Expand Down Expand Up @@ -110,20 +110,21 @@ async def client(kernel: Kernel) -> AsyncGenerator[AsyncKernelClient, Any]:
await anyio.sleep(0)


@pytest.fixture(scope="module", params=KernelName)
@pytest.fixture(scope="module", params=["async", "async-trio"])
def kernel_name(request):
return request.param


@pytest.fixture(scope="module")
async def subprocess_kernels_client(anyio_backend, tmp_path_factory, kernel_name: KernelName, transport):
async def subprocess_kernels_client(anyio_backend, tmp_path_factory, kernel_name, transport: str):
"""
Starts a kernel in a subprocess and returns an AsyncKernelCient that is connected to it.
"""
assert anyio_backend[0] == "asyncio", "Asyncio is required for the client"
connection_file = tmp_path_factory.mktemp("async_kernel") / "temp_connection.json"
kwgs = {"interface.transport": transport}
command = make_argv(connection_file=connection_file, kernel_name=kernel_name, **kwgs)
backend = "trio" if "trio" in kernel_name else "asyncio"
kwgs = {"interface.transport": transport, "interface.backend": backend}
command = make_argv(connection_file=connection_file, kernel_name=kernel_name, **kwgs) # pyright: ignore[reportArgumentType]
process = await anyio.open_process([*command, "--no-print_kernel_messages"])
async with process:
while not connection_file.exists() or not connection_file.stat().st_size:
Expand Down
16 changes: 3 additions & 13 deletions tests/test_enter_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,12 @@
import pytest

from async_kernel.kernel import Kernel
from async_kernel.typing import Backend, KernelName


@pytest.fixture(scope="module", params=list(KernelName))
def kernel_name(request):
return request.param


@pytest.fixture(scope="module")
def anyio_backend(kernel_name: KernelName):
return "trio" if kernel_name is KernelName.trio else "asyncio"


async def test_start_kernel_in_context(anyio_backend: Backend, kernel_name: KernelName):
@pytest.mark.parametrize("anyio_backend", argvalues=["asyncio", "trio"])
async def test_start_kernel_in_context(anyio_backend):
async with Kernel({"print_kernel_messages": False}) as kernel:
assert kernel.kernel_name == kernel_name
assert kernel.kernel_name == {"asyncio": "async", "trio": "async-trio"}[anyio_backend]
connection_file = kernel.connection_file
# Test prohibit nested async context.
with pytest.raises(RuntimeError, match="this Kernel has already been entered"):
Expand Down
5 changes: 2 additions & 3 deletions tests/test_kernelspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@

from async_kernel.interface.zmq import ZMQKernelInterface
from async_kernel.kernelspec import DEFAULT_START_INTERFACE, import_start_interface, write_kernel_spec
from async_kernel.typing import KernelName


@pytest.mark.parametrize(
("kernel_name", "start_interface"),
[
(KernelName.trio, DEFAULT_START_INTERFACE),
("trio", DEFAULT_START_INTERFACE),
("function_factory", "custom"),
],
)
def test_write_kernel_spec(kernel_name: KernelName, start_interface, tmp_path, monkeypatch):
def test_write_kernel_spec(kernel_name, start_interface, tmp_path, monkeypatch):
if start_interface == "custom":

def my_start_interface(settings: dict | None):
Expand Down
Loading