diff --git a/.gitignore b/.gitignore index 70e4192c..d4751320 100644 --- a/.gitignore +++ b/.gitignore @@ -344,6 +344,3 @@ $RECYCLE.BIN/ .jupyter_ystore.db .jupyter_ystore.db-journal fps_cli_args.toml - -# pixi environments -.pixi diff --git a/jupyverse_api/jupyverse_api/cli.py b/jupyverse_api/jupyverse_api/cli.py index 20dd8a2f..bdb0ff03 100644 --- a/jupyverse_api/jupyverse_api/cli.py +++ b/jupyverse_api/jupyverse_api/cli.py @@ -2,7 +2,7 @@ from typing import List, Tuple import rich_click as click -from asphalt.core.cli import run +from asphalt.core._cli import run if sys.version_info < (3, 10): from importlib_metadata import entry_points @@ -66,8 +66,6 @@ def main( set_list.append(f"component.allow_origin={allow_origin}") config = get_config(disable) run.callback( - unsafe=False, - loop=None, set_=set_list, service=None, configfile=[config], diff --git a/jupyverse_api/jupyverse_api/contents/__init__.py b/jupyverse_api/jupyverse_api/contents/__init__.py index 4296e29b..c9af6975 100644 --- a/jupyverse_api/jupyverse_api/contents/__init__.py +++ b/jupyverse_api/jupyverse_api/contents/__init__.py @@ -1,4 +1,3 @@ -import asyncio from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, List, Optional, Union @@ -13,8 +12,13 @@ class FileIdManager(ABC): - stop_watching_files: asyncio.Event - stopped_watching_files: asyncio.Event + @abstractmethod + async def start(self) -> None: + ... + + @abstractmethod + async def stop(self) -> None: + ... @abstractmethod async def get_path(self, file_id: str) -> str: diff --git a/jupyverse_api/jupyverse_api/main/__init__.py b/jupyverse_api/jupyverse_api/main/__init__.py index cfbaa3a3..7656c2b6 100644 --- a/jupyverse_api/jupyverse_api/main/__init__.py +++ b/jupyverse_api/jupyverse_api/main/__init__.py @@ -3,8 +3,9 @@ import webbrowser from typing import Any, Callable, Dict, Sequence, Tuple +from anyio import Event from asgiref.typing import ASGI3Application -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource, start_service_task from asphalt.web.fastapi import FastAPIComponent from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -22,14 +23,11 @@ def __init__( super().__init__() self.mount_path = mount_path - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(FastAPI) + async def start(self) -> None: + app = await get_resource(FastAPI, wait=True) _app = App(app, mount_path=self.mount_path) - ctx.add_resource(_app) + add_resource(_app) class JupyverseComponent(FastAPIComponent): @@ -67,22 +65,27 @@ def __init__( self.port = port self.open_browser = open_browser self.query_params = query_params + self.lifespan = Lifespan() - async def start( - self, - ctx: Context, - ) -> None: + async def start(self) -> None: query_params = QueryParams(d={}) host = self.host if not host.startswith("http"): host = f"http://{host}" host_url = Host(url=f"{host}:{self.port}/") - ctx.add_resource(query_params) - ctx.add_resource(host_url) + add_resource(query_params) + add_resource(host_url) + add_resource(self.lifespan) - await super().start(ctx) + await super().start() # at this point, the server has started + await start_service_task( + self.lifespan.shutdown_request.wait, + "Server lifespan notifier", + teardown_action=self.lifespan.shutdown_request.set, + ) + if self.open_browser: qp = query_params.d if self.query_params: @@ -97,3 +100,8 @@ class QueryParams(BaseModel): class Host(BaseModel): url: str + + +class Lifespan: + def __init__(self): + self.shutdown_request = Event() diff --git a/jupyverse_api/pyproject.toml b/jupyverse_api/pyproject.toml index a8168883..ee1ea543 100644 --- a/jupyverse_api/pyproject.toml +++ b/jupyverse_api/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] @@ -28,8 +29,9 @@ dependencies = [ "pydantic >=2,<3", "fastapi >=0.95.0,<1", "rich-click >=1.6.1,<2", - "asphalt >=4.11.0,<5", - "asphalt-web[fastapi] >=1.1.0,<2", + "importlib_metadata >=3.6; python_version<'3.10'", + #"asphalt >=4.11.0,<5", + #"asphalt-web[fastapi] >=1.1.0,<2", ] dynamic = ["version"] diff --git a/plugins/auth/fps_auth/main.py b/plugins/auth/fps_auth/main.py index 8670a084..df086639 100644 --- a/plugins/auth/fps_auth/main.py +++ b/plugins/auth/fps_auth/main.py @@ -1,6 +1,6 @@ import logging -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from fastapi_users.exceptions import UserAlreadyExists from jupyverse_api.app import App @@ -18,17 +18,14 @@ class AuthComponent(Component): def __init__(self, **kwargs): self.auth_config = _AuthConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.auth_config, types=AuthConfig) + async def start(self) -> None: + add_resource(self.auth_config, types=AuthConfig) - app = await ctx.request_resource(App) - frontend_config = await ctx.request_resource(FrontendConfig) + app = await get_resource(App, wait=True) + frontend_config = await get_resource(FrontendConfig, wait=True) auth = auth_factory(app, self.auth_config, frontend_config) - ctx.add_resource(auth, types=Auth) + add_resource(auth, types=Auth) await auth.db.create_db_and_tables() @@ -59,8 +56,8 @@ async def start( ) if self.auth_config.mode == "token": - query_params = await ctx.request_resource(QueryParams) - host = await ctx.request_resource(Host) + query_params = await get_resource(QueryParams, wait=True) + host = await get_resource(Host, wait=True) query_params.d["token"] = self.auth_config.token logger.info("") diff --git a/plugins/auth/fps_auth/routes.py b/plugins/auth/fps_auth/routes.py index dd24863f..5b5ab11a 100644 --- a/plugins/auth/fps_auth/routes.py +++ b/plugins/auth/fps_auth/routes.py @@ -72,6 +72,7 @@ async def get_users( async def get_api_me( permissions: Optional[str] = None, user: UserRead = Depends(backend.current_user()), + update_user = Depends(backend.update_user), ): checked_permissions: Dict[str, List[str]] = {} if permissions is None: @@ -96,6 +97,14 @@ async def get_api_me( moon = get_anonymous_username() identity["name"] = f"Anonymous {moon}" identity["display_name"] = f"Anonymous {moon}" + identity["initials"] = f"A{moon[0]}" + await update_user( + dict( + name=identity["name"], + display_name=identity["display_name"], + permissions=checked_permissions, + ) + ) return { "identity": identity, "permissions": checked_permissions, diff --git a/plugins/auth_fief/fps_auth_fief/main.py b/plugins/auth_fief/fps_auth_fief/main.py index ddc8224f..04632df9 100644 --- a/plugins/auth_fief/fps_auth_fief/main.py +++ b/plugins/auth_fief/fps_auth_fief/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth, AuthConfig @@ -11,13 +11,10 @@ class AuthFiefComponent(Component): def __init__(self, **kwargs): self.auth_fief_config = _AuthFiefConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.auth_fief_config, types=AuthConfig) + async def start(self) -> None: + add_resource(self.auth_fief_config, types=AuthConfig) - app = await ctx.request_resource(App) + app = await get_resource(App, wait=True) auth_fief = auth_factory(app, self.auth_fief_config) - ctx.add_resource(auth_fief, types=Auth) + add_resource(auth_fief, types=Auth) diff --git a/plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py b/plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py index 8b9a8b71..27f9e120 100644 --- a/plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py +++ b/plugins/auth_jupyterhub/fps_auth_jupyterhub/main.py @@ -1,5 +1,10 @@ -import httpx -from asphalt.core import Component, ContainerComponent, Context, context_teardown +from asphalt.core import ( + Component, + ContainerComponent, + add_resource, + get_resource, + start_service_task, +) from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from jupyverse_api.app import App @@ -11,40 +16,29 @@ class _AuthJupyterHubComponent(Component): - @context_teardown - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - db_session = await ctx.request_resource(AsyncSession) - db_engine = await ctx.request_resource(AsyncEngine) - - http_client = httpx.AsyncClient() - auth_jupyterhub = auth_factory(app, db_session, http_client) - ctx.add_resource(auth_jupyterhub, types=Auth) + async def start(self) -> None: + app = await get_resource(App, wait=True) + db_session = await get_resource(AsyncSession, wait=True) + db_engine = await get_resource(AsyncEngine, wait=True) + + auth_jupyterhub = auth_factory(app, db_session) + await start_service_task(auth_jupyterhub.start, "JupyterHub Auth", auth_jupyterhub.stop) + add_resource(auth_jupyterhub, types=Auth) async with db_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - yield - - await http_client.aclose() - class AuthJupyterHubComponent(ContainerComponent): def __init__(self, **kwargs): self.auth_jupyterhub_config = AuthJupyterHubConfig(**kwargs) super().__init__() - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.auth_jupyterhub_config, types=AuthConfig) + async def start(self) -> None: + add_resource(self.auth_jupyterhub_config, types=AuthConfig) self.add_component( "sqlalchemy", url=self.auth_jupyterhub_config.db_url, ) self.add_component("auth_jupyterhub", type=_AuthJupyterHubComponent) - await super().start(ctx) + await super().start() diff --git a/plugins/auth_jupyterhub/fps_auth_jupyterhub/routes.py b/plugins/auth_jupyterhub/fps_auth_jupyterhub/routes.py index e0cb72e6..b132b2c2 100644 --- a/plugins/auth_jupyterhub/fps_auth_jupyterhub/routes.py +++ b/plugins/auth_jupyterhub/fps_auth_jupyterhub/routes.py @@ -1,14 +1,16 @@ from __future__ import annotations -import asyncio import json import os from datetime import datetime +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import httpx +from anyio import TASK_STATUS_IGNORED, Lock, create_task_group +from anyio.abc import TaskStatus from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, WebSocket, status from fastapi.responses import RedirectResponse +from httpx import AsyncClient from jupyterhub.services.auth import HubOAuth from jupyterhub.utils import isoformat from sqlalchemy.ext.asyncio import AsyncSession @@ -26,16 +28,15 @@ def auth_factory( app: App, db_session: AsyncSession, - http_client: httpx.AsyncClient, ): class AuthJupyterHub(Auth, Router): def __init__(self) -> None: super().__init__(app) self.hub_auth = HubOAuth() - self.db_lock = asyncio.Lock() + self.db_lock = Lock() self.activity_url = os.environ.get("JUPYTERHUB_ACTIVITY_URL") self.server_name = os.environ.get("JUPYTERHUB_SERVER_NAME") - self.background_tasks = set() + self.http_client = AsyncClient() router = APIRouter() @@ -123,8 +124,9 @@ async def _( "Content-Type": "application/json", } last_activity = isoformat(datetime.utcnow()) - task = asyncio.create_task( - http_client.post( + self.task_group.start_soon( + partial( + self.http_client.post, self.activity_url, headers=headers, json={ @@ -132,8 +134,6 @@ async def _( }, ) ) - self.background_tasks.add(task) - task.add_done_callback(self.background_tasks.discard) return user if permissions: @@ -193,4 +193,13 @@ async def _( return _ + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + async with create_task_group() as tg: + self.task_group = tg + task_status.started() + + async def stop(self) -> None: + await self.http_client.aclose() + self.task_group.cancel_scope().cancel() + return AuthJupyterHub() diff --git a/plugins/auth_jupyterhub/pyproject.toml b/plugins/auth_jupyterhub/pyproject.toml index 62b5037b..b4121812 100644 --- a/plugins/auth_jupyterhub/pyproject.toml +++ b/plugins/auth_jupyterhub/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "httpx >=0.24.1,<1", "jupyterhub >=4.0.1,<5", "jupyverse-api >=0.1.2,<1", + "anyio", ] [[project.authors]] diff --git a/plugins/contents/fps_contents/fileid.py b/plugins/contents/fps_contents/fileid.py index f489c59d..6d6a78fc 100644 --- a/plugins/contents/fps_contents/fileid.py +++ b/plugins/contents/fps_contents/fileid.py @@ -1,28 +1,28 @@ -import asyncio +from __future__ import annotations + import logging +import sqlite3 from typing import Dict, List, Optional, Set from uuid import uuid4 -import aiosqlite -from anyio import Path +from anyio import Event, Lock, Path +from sqlite_anyio import connect from watchfiles import Change, awatch -from jupyverse_api import Singleton - logger = logging.getLogger("contents") class Watcher: def __init__(self, path: str) -> None: self.path = path - self._event = asyncio.Event() + self._event = Event() def __aiter__(self): return self async def __anext__(self): await self._event.wait() - self._event.clear() + self._event = Event() return self._change def notify(self, change): @@ -30,128 +30,133 @@ def notify(self, change): self._event.set() -class FileIdManager(metaclass=Singleton): +class FileIdManager: db_path: str - initialized: asyncio.Event + initialized: Event watchers: Dict[str, List[Watcher]] - lock: asyncio.Lock + lock: Lock def __init__(self, db_path: str = ".fileid.db"): self.db_path = db_path - self.initialized = asyncio.Event() + self.initialized = Event() self.watchers = {} - self.watch_files_task = asyncio.create_task(self.watch_files()) - self.stop_watching_files = asyncio.Event() - self.stopped_watching_files = asyncio.Event() - self.lock = asyncio.Lock() + self.stop_event = Event() + self.lock = Lock() + + async def start(self) -> None: + self._db = await connect(self.db_path) + try: + await self.watch_files() + except sqlite3.ProgrammingError: + pass + + async def stop(self) -> None: + await self._db.close() + self.stop_event.set() async def get_id(self, path: str) -> Optional[str]: await self.initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute("SELECT id FROM fileids WHERE path = ?", (path,)) as cursor: - async for (idx,) in cursor: - return idx - return None + cursor = await self._db.cursor() + await cursor.execute("SELECT id FROM fileids WHERE path = ?", (path,)) + for (idx,) in await cursor.fetchall(): + return idx + return None async def get_path(self, idx: str) -> Optional[str]: await self.initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute("SELECT path FROM fileids WHERE id = ?", (idx,)) as cursor: - async for (path,) in cursor: - return path - return None + cursor = await self._db.cursor() + await cursor.execute("SELECT path FROM fileids WHERE id = ?", (idx,)) + for (path,) in await cursor.fetchall(): + return path + return None async def index(self, path: str) -> Optional[str]: await self.initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - apath = Path(path) - if not await apath.exists(): - return None + apath = Path(path) + if not await apath.exists(): + return None - idx = uuid4().hex - mtime = (await apath.stat()).st_mtime - await db.execute("INSERT INTO fileids VALUES (?, ?, ?)", (idx, path, mtime)) - await db.commit() - return idx + idx = uuid4().hex + mtime = (await apath.stat()).st_mtime + cursor = await self._db.cursor() + await cursor.execute("INSERT INTO fileids VALUES (?, ?, ?)", (idx, path, mtime)) + await self._db.commit() + return idx async def watch_files(self): async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - await db.execute("DROP TABLE IF EXISTS fileids") - await db.execute( - "CREATE TABLE fileids " - "(id TEXT PRIMARY KEY, path TEXT NOT NULL UNIQUE, mtime REAL NOT NULL)" - ) - await db.commit() + cursor = await self._db.cursor() + await cursor.execute("DROP TABLE IF EXISTS fileids") + await cursor.execute( + "CREATE TABLE fileids " + "(id TEXT PRIMARY KEY, path TEXT NOT NULL UNIQUE, mtime REAL NOT NULL)" + ) + await self._db.commit() # index files async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async for path in Path().rglob("*"): - idx = uuid4().hex - mtime = (await path.stat()).st_mtime - await db.execute( - "INSERT INTO fileids VALUES (?, ?, ?)", (idx, str(path), mtime) - ) - await db.commit() - self.initialized.set() - - async for changes in awatch(".", stop_event=self.stop_watching_files): + cursor = await self._db.cursor() + async for path in Path().rglob("*"): + idx = uuid4().hex + mtime = (await path.stat()).st_mtime + await cursor.execute( + "INSERT INTO fileids VALUES (?, ?, ?)", (idx, str(path), mtime) + ) + await self._db.commit() + self.initialized.set() + + async for changes in awatch(".", stop_event=self.stop_event): async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - deleted_paths = set() - added_paths = set() - for change, changed_path in changes: - # get relative path - changed_path = Path(changed_path).relative_to(await Path().absolute()) - changed_path_str = str(changed_path) - - if change == Change.deleted: - logger.debug("File %s was deleted", changed_path_str) - async with db.execute( - "SELECT COUNT(*) FROM fileids WHERE path = ?", (changed_path_str,) - ) as cursor: - if not (await cursor.fetchone())[0]: - # path is not indexed, ignore - logger.debug( - "File %s is not indexed, ignoring", changed_path_str - ) - continue - # path is indexed - await maybe_rename( - db, changed_path_str, deleted_paths, added_paths, False - ) - elif change == Change.added: - logger.debug("File %s was added", changed_path_str) - await maybe_rename( - db, changed_path_str, added_paths, deleted_paths, True - ) - elif change == Change.modified: - logger.debug("File %s was modified", changed_path_str) - if changed_path_str == self.db_path: - continue - async with db.execute( - "SELECT COUNT(*) FROM fileids WHERE path = ?", (changed_path_str,) - ) as cursor: - if not (await cursor.fetchone())[0]: - # path is not indexed, ignore - logger.debug( - "File %s is not indexed, ignoring", changed_path_str - ) - continue - mtime = (await changed_path.stat()).st_mtime - await db.execute( - "UPDATE fileids SET mtime = ? WHERE path = ?", - (mtime, changed_path_str), - ) - - for path in deleted_paths - added_paths: - logger.debug("Unindexing file %s ", path) - await db.execute("DELETE FROM fileids WHERE path = ?", (path,)) - await db.commit() + deleted_paths = set() + added_paths = set() + cursor = await self._db.cursor() + for change, changed_path in changes: + # get relative path + changed_path = Path(changed_path).relative_to(await Path().absolute()) + changed_path_str = str(changed_path) + + if change == Change.deleted: + logger.debug("File %s was deleted", changed_path_str) + await cursor.execute( + "SELECT COUNT(*) FROM fileids WHERE path = ?", (changed_path_str,) + ) + if not (await cursor.fetchone())[0]: + # path is not indexed, ignore + logger.debug("File %s is not indexed, ignoring", changed_path_str) + continue + # path is indexed + await maybe_rename( + self._db, changed_path_str, deleted_paths, added_paths, False + ) + elif change == Change.added: + logger.debug("File %s was added", changed_path_str) + await maybe_rename( + self._db, changed_path_str, added_paths, deleted_paths, True + ) + elif change == Change.modified: + logger.debug("File %s was modified", changed_path_str) + if changed_path_str == self.db_path: + continue + await cursor.execute( + "SELECT COUNT(*) FROM fileids WHERE path = ?", (changed_path_str,) + ) + if not (await cursor.fetchone())[0]: + # path is not indexed, ignore + logger.debug("File %s is not indexed, ignoring", changed_path_str) + continue + mtime = (await changed_path.stat()).st_mtime + await cursor.execute( + "UPDATE fileids SET mtime = ? WHERE path = ?", + (mtime, changed_path_str), + ) + + for path in deleted_paths - added_paths: + logger.debug("Unindexing file %s ", path) + await cursor.execute("DELETE FROM fileids WHERE path = ?", (path,)) + await self._db.commit() for change in changes: changed_path = change[1] @@ -161,8 +166,6 @@ async def watch_files(self): for watcher in self.watchers.get(relative_changed_path, []): watcher.notify(relative_change) - self.stopped_watching_files.set() - def watch(self, path: str) -> Watcher: watcher = Watcher(path) self.watchers.setdefault(path, []).append(watcher) @@ -174,11 +177,12 @@ def unwatch(self, path: str, watcher: Watcher): async def get_mtime(path, db) -> Optional[float]: if db: - async with db.execute("SELECT mtime FROM fileids WHERE path = ?", (path,)) as cursor: - async for (mtime,) in cursor: - return mtime - # deleted file is not in database, shouldn't happen - return None + cursor = await db.cursor() + await cursor.execute("SELECT mtime FROM fileids WHERE path = ?", (path,)) + for (mtime,) in await cursor.fetchall(): + return mtime + # deleted file is not in database, shouldn't happen + return None try: mtime = (await Path(path).stat()).st_mtime except FileNotFoundError: @@ -204,7 +208,8 @@ async def maybe_rename( if is_added_path: path1, path2 = path2, path1 logger.debug("File %s was renamed to %s", path1, path2) - await db.execute("UPDATE fileids SET path = ? WHERE path = ?", (path2, path1)) + cursor = await db.cursor() + await cursor.execute("UPDATE fileids SET path = ? WHERE path = ?", (path2, path1)) other_paths.remove(other_path) return changed_paths.add(changed_path) diff --git a/plugins/contents/fps_contents/main.py b/plugins/contents/fps_contents/main.py index 81f6fd18..8582ce89 100644 --- a/plugins/contents/fps_contents/main.py +++ b/plugins/contents/fps_contents/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -8,12 +8,9 @@ class ContentsComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] contents = _Contents(app, auth) - ctx.add_resource(contents, types=Contents) + add_resource(contents, types=Contents) diff --git a/plugins/contents/fps_contents/routes.py b/plugins/contents/fps_contents/routes.py index 7f759b9e..bc7667ed 100644 --- a/plugins/contents/fps_contents/routes.py +++ b/plugins/contents/fps_contents/routes.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import base64 import json import os import shutil -from datetime import datetime +from datetime import datetime, timezone from http import HTTPStatus from pathlib import Path from typing import Dict, List, Optional, Union, cast @@ -25,6 +27,8 @@ class _Contents(Contents): + _file_id_manager: FileIdManager | None = None + async def create_checkpoint( self, path, @@ -245,7 +249,9 @@ async def write_content(self, content: Union[SaveContent, Dict]) -> None: @property def file_id_manager(self): - return FileIdManager() + if self._file_id_manager is None: + self._file_id_manager = FileIdManager() + return self._file_id_manager def get_available_path(path: Path, sep: str = "") -> Path: @@ -268,12 +274,16 @@ def get_available_path(path: Path, sep: str = "") -> Path: def get_file_modification_time(path: Path): if path.exists(): - return datetime.utcfromtimestamp(path.stat().st_mtime).isoformat() + "Z" + return datetime.fromtimestamp(path.stat().st_mtime, tz=timezone.utc).isoformat().replace( + "+00:00", "Z" + ) def get_file_creation_time(path: Path): if path.exists(): - return datetime.utcfromtimestamp(path.stat().st_ctime).isoformat() + "Z" + return datetime.fromtimestamp(path.stat().st_ctime, tz=timezone.utc).isoformat().replace( + "+00:00", "Z" + ) def get_file_size(path: Path) -> Optional[int]: diff --git a/plugins/contents/pyproject.toml b/plugins/contents/pyproject.toml index 22ac019b..111ed3fa 100644 --- a/plugins/contents/pyproject.toml +++ b/plugins/contents/pyproject.toml @@ -9,7 +9,7 @@ keywords = ["jupyter", "server", "fastapi", "plugins"] requires-python = ">=3.8" dependencies = [ "watchfiles >=0.18.1,<1", - "aiosqlite >=0.17.0,<1", + "sqlite-anyio >=0.2.0,<0.3.0", "anyio>=3.6.2,<5", "jupyverse-api >=0.1.2,<1", ] diff --git a/plugins/frontend/fps_frontend/main.py b/plugins/frontend/fps_frontend/main.py index b34b488b..1b502de3 100644 --- a/plugins/frontend/fps_frontend/main.py +++ b/plugins/frontend/fps_frontend/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource from jupyverse_api.frontend import FrontendConfig @@ -7,8 +7,5 @@ class FrontendComponent(Component): def __init__(self, **kwargs): self.frontend_config = FrontendConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.frontend_config, types=FrontendConfig) + async def start(self) -> None: + add_resource(self.frontend_config, types=FrontendConfig) diff --git a/plugins/jupyterlab/fps_jupyterlab/main.py b/plugins/jupyterlab/fps_jupyterlab/main.py index 2cd31dd9..31865b8b 100644 --- a/plugins/jupyterlab/fps_jupyterlab/main.py +++ b/plugins/jupyterlab/fps_jupyterlab/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -13,16 +13,13 @@ class JupyterLabComponent(Component): def __init__(self, **kwargs): self.jupyterlab_config = JupyterLabConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - ctx.add_resource(self.jupyterlab_config, types=JupyterLabConfig) + async def start(self) -> None: + add_resource(self.jupyterlab_config, types=JupyterLabConfig) - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - frontend_config = await ctx.request_resource(FrontendConfig) - lab = await ctx.request_resource(Lab) # type: ignore + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + frontend_config = await get_resource(FrontendConfig, wait=True) + lab = await get_resource(Lab, wait=True) # type: ignore[type-abstract] jupyterlab = _JupyterLab(app, self.jupyterlab_config, auth, frontend_config, lab) - ctx.add_resource(jupyterlab, types=JupyterLab) + add_resource(jupyterlab, types=JupyterLab) diff --git a/plugins/kernels/fps_kernels/kernel_driver/connect.py b/plugins/kernels/fps_kernels/kernel_driver/connect.py index 8c177a89..0eb4aaaa 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/connect.py +++ b/plugins/kernels/fps_kernels/kernel_driver/connect.py @@ -1,14 +1,17 @@ -import asyncio +from __future__ import annotations + import json import os import socket +import subprocess import sys import tempfile import uuid from typing import Dict, Optional, Tuple, Union import zmq -import zmq.asyncio +from anyio import open_process +from anyio.abc import Process from zmq.asyncio import Socket channel_socket_types = { @@ -71,8 +74,8 @@ def read_connection_file(fname: str) -> cfg_t: async def launch_kernel( - kernelspec_path: str, connection_file_path: str, kernel_cwd: str, capture_output: bool -) -> asyncio.subprocess.Process: + kernelspec_path: str, connection_file_path: str, kernel_cwd: str | None, capture_output: bool +) -> Process: with open(kernelspec_path) as f: kernelspec = json.load(f) cmd = [s.format(connection_file=connection_file_path) for s in kernelspec["argv"]] @@ -82,18 +85,16 @@ async def launch_kernel( "python%i.%i" % sys.version_info[:2], }: cmd[0] = sys.executable - if kernel_cwd: - prev_dir = os.getcwd() - os.chdir(kernel_cwd) if capture_output: - p = await asyncio.create_subprocess_exec( - *cmd, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.STDOUT - ) + stdout = subprocess.DEVNULL + stderr = subprocess.STDOUT else: - p = await asyncio.create_subprocess_exec(*cmd) - if kernel_cwd: - os.chdir(prev_dir) - return p + stdout = None + stderr = None + if not kernel_cwd: + kernel_cwd = None + process = await open_process(cmd, stdout=stdout, stderr=stderr, cwd=kernel_cwd) + return process def create_socket(channel: str, cfg: cfg_t, identity: Optional[bytes] = None) -> Socket: diff --git a/plugins/kernels/fps_kernels/kernel_driver/driver.py b/plugins/kernels/fps_kernels/kernel_driver/driver.py index 6c2fba20..362bbbfc 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/driver.py +++ b/plugins/kernels/fps_kernels/kernel_driver/driver.py @@ -1,9 +1,17 @@ -import asyncio -import os import time import uuid -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, Optional, cast +import anyio +from anyio import ( + TASK_STATUS_IGNORED, + Event, + create_memory_object_stream, + create_task_group, + fail_after, +) +from anyio.abc import TaskGroup, TaskStatus +from anyio.streams.stapled import StapledObjectStream from pycrdt import Array, Map from jupyverse_api.yjs import Yjs @@ -19,6 +27,8 @@ def deadline_to_timeout(deadline: float) -> float: class KernelDriver: + task_group: TaskGroup + def __init__( self, kernel_name: str = "", @@ -29,6 +39,7 @@ def __init__( capture_kernel_output: bool = True, yjs: Optional[Yjs] = None, ) -> None: + self.write_connection_file = write_connection_file self.capture_kernel_output = capture_kernel_output self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name) self.kernel_cwd = kernel_cwd @@ -43,40 +54,56 @@ def __init__( self.key = cast(str, self.connection_cfg["key"]) self.session_id = uuid.uuid4().hex self.msg_cnt = 0 - self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {} - self.comm_messages: asyncio.Queue = asyncio.Queue() - self.tasks: List[asyncio.Task] = [] + self.execute_requests: Dict[str, Dict[str, StapledObjectStream]] = {} + self.comm_messages: StapledObjectStream = StapledObjectStream( + *create_memory_object_stream[dict](max_buffer_size=1024) + ) + self.stopped_event = Event() async def restart(self, startup_timeout: float = float("inf")) -> None: - for task in self.tasks: - task.cancel() - msg = create_message("shutdown_request", content={"restart": True}) - await send_message(msg, self.control_channel, self.key, change_date_to_str=True) - while True: - msg = cast( - Dict[str, Any], await receive_message(self.control_channel, change_str_to_date=True) - ) - if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]: - break - await self._wait_for_ready(startup_timeout) - self.tasks = [] - self.listen_channels() + self.task_group.cancel_scope.cancel() + await self.stopped_event.wait() + self.stopped_event = Event() + async with create_task_group() as tg: + self.task_group = tg + msg = create_message("shutdown_request", content={"restart": True}) + await send_message(msg, self.control_channel, self.key, change_date_to_str=True) + while True: + msg = cast( + Dict[str, Any], + await receive_message(self.control_channel, change_str_to_date=True), + ) + if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]: + break + await self._wait_for_ready(startup_timeout) + self.listen_channels() + tg.start_soon(self._handle_comms) - async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None: - self.kernel_process = await launch_kernel( - self.kernelspec_path, - self.connection_file_path, - self.kernel_cwd, - self.capture_kernel_output, - ) - if connect: - await self.connect(startup_timeout) + async def start( + self, + startup_timeout: float = float("inf"), + connect: bool = True, + *, + task_status: TaskStatus[None] = TASK_STATUS_IGNORED, + ) -> None: + async with create_task_group() as tg: + self.task_group = tg + self.kernel_process = await launch_kernel( + self.kernelspec_path, + self.connection_file_path, + self.kernel_cwd, + self.capture_kernel_output, + ) + if connect: + await self.connect() + task_status.started() + self.stopped_event.set() async def connect(self, startup_timeout: float = float("inf")) -> None: self.connect_channels() await self._wait_for_ready(startup_timeout) self.listen_channels() - self.tasks.append(asyncio.create_task(self._handle_comms())) + self.task_group.start_soon(self._handle_comms) def connect_channels(self, connection_cfg: Optional[cfg_t] = None): connection_cfg = connection_cfg or self.connection_cfg @@ -85,31 +112,38 @@ def connect_channels(self, connection_cfg: Optional[cfg_t] = None): self.iopub_channel = connect_channel("iopub", connection_cfg) def listen_channels(self): - self.tasks.append(asyncio.create_task(self.listen_iopub())) - self.tasks.append(asyncio.create_task(self.listen_shell())) + (self.task_group.start_soon(self.listen_iopub),) + (self.task_group.start_soon(self.listen_shell),) async def stop(self) -> None: - self.kernel_process.kill() + try: + self.kernel_process.terminate() + except ProcessLookupError: + pass await self.kernel_process.wait() - os.remove(self.connection_file_path) - for task in self.tasks: - task.cancel() + self.task_group.cancel_scope.cancel() + if self.write_connection_file: + path = anyio.Path(self.connection_file_path) + try: + await path.unlink() + except Exception: + pass async def listen_iopub(self): while True: msg = await receive_message(self.iopub_channel, change_str_to_date=True) parent_id = msg["parent_header"].get("msg_id") if msg["msg_type"] in ("comm_open", "comm_msg"): - self.comm_messages.put_nowait(msg) + await self.comm_messages.send(msg) elif parent_id in self.execute_requests.keys(): - self.execute_requests[parent_id]["iopub_msg"].put_nowait(msg) + await self.execute_requests[parent_id]["iopub_msg"].send(msg) async def listen_shell(self): while True: msg = await receive_message(self.shell_channel, change_str_to_date=True) msg_id = msg["parent_header"].get("msg_id") if msg_id in self.execute_requests.keys(): - self.execute_requests[msg_id]["shell_msg"].put_nowait(msg) + await self.execute_requests[msg_id]["shell_msg"].send(msg) async def execute( self, @@ -132,32 +166,32 @@ async def execute( self.msg_cnt += 1 await send_message(msg, self.shell_channel, self.key, change_date_to_str=True) self.execute_requests[msg_id] = { - "iopub_msg": asyncio.Queue(), - "shell_msg": asyncio.Queue(), + "iopub_msg": StapledObjectStream( + *create_memory_object_stream[dict](max_buffer_size=1024) + ), + "shell_msg": StapledObjectStream( + *create_memory_object_stream[dict](max_buffer_size=1024) + ), } if wait_for_executed: deadline = time.time() + timeout while True: try: - msg = await asyncio.wait_for( - self.execute_requests[msg_id]["iopub_msg"].get(), - deadline_to_timeout(deadline), - ) - except asyncio.TimeoutError: + with fail_after(deadline_to_timeout(deadline)): + msg = await self.execute_requests[msg_id]["iopub_msg"].receive() + except TimeoutError: error_message = f"Kernel didn't respond in {timeout} seconds" raise RuntimeError(error_message) await self._handle_outputs(ycell["outputs"], msg) if ( - (msg["header"]["msg_type"] == "status" - and msg["content"]["execution_state"] == "idle") + msg["header"]["msg_type"] == "status" + and msg["content"]["execution_state"] == "idle" ): break try: - msg = await asyncio.wait_for( - self.execute_requests[msg_id]["shell_msg"].get(), - deadline_to_timeout(deadline), - ) - except asyncio.TimeoutError: + with fail_after(deadline_to_timeout(deadline)): + msg = await self.execute_requests[msg_id]["shell_msg"].receive() + except TimeoutError: error_message = f"Kernel didn't respond in {timeout} seconds" raise RuntimeError(error_message) with ycell.doc.transaction(): @@ -165,31 +199,32 @@ async def execute( ycell["execution_state"] = "idle" del self.execute_requests[msg_id] else: - self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell))) + self.task_group.start_soon(lambda: self._handle_iopub(msg_id, ycell)) async def _handle_iopub(self, msg_id: str, ycell: Map) -> None: while True: - msg = await self.execute_requests[msg_id]["iopub_msg"].get() + msg = await self.execute_requests[msg_id]["iopub_msg"].receive() await self._handle_outputs(ycell["outputs"], msg) if ( - (msg["header"]["msg_type"] == "status" - and msg["content"]["execution_state"] == "idle") + msg["header"]["msg_type"] == "status" + and msg["content"]["execution_state"] == "idle" ): - msg = await self.execute_requests[msg_id]["shell_msg"].get() + msg = await self.execute_requests[msg_id]["shell_msg"].receive() with ycell.doc.transaction(): ycell["execution_count"] = msg["content"]["execution_count"] ycell["execution_state"] = "idle" + break async def _handle_comms(self) -> None: if self.yjs is None or self.yjs.widgets is None: # type: ignore return while True: - msg = await self.comm_messages.get() + msg = await self.comm_messages.receive() msg_type = msg["header"]["msg_type"] if msg_type == "comm_open": comm_id = msg["content"]["comm_id"] - comm = Comm(comm_id, self.shell_channel, self.session_id, self.key) + comm = Comm(comm_id, self.shell_channel, self.session_id, self.key, self.task_group) self.yjs.widgets.comm_open(msg, comm) # type: ignore elif msg_type == "comm_msg": self.yjs.widgets.comm_msg(msg) # type: ignore @@ -225,13 +260,13 @@ async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]): # TODO: uncomment when changes are made in jupyter-ydoc if (not outputs) or (outputs[-1]["name"] != content["name"]): # type: ignore outputs.append( - #Map( + # Map( # { # "name": content["name"], # "output_type": msg_type, # "text": Array([content["text"]]), # } - #) + # ) { "name": content["name"], "output_type": msg_type, @@ -239,7 +274,7 @@ async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]): } ) else: - #outputs[-1]["text"].append(content["text"]) # type: ignore + # outputs[-1]["text"].append(content["text"]) # type: ignore last_output = outputs[-1] last_output["text"].append(content["text"]) # type: ignore outputs[-1] = last_output @@ -274,11 +309,14 @@ async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]): class Comm: - def __init__(self, comm_id: str, shell_channel, session_id: str, key: str): + def __init__( + self, comm_id: str, shell_channel, session_id: str, key: str, task_group: TaskGroup + ): self.comm_id = comm_id self.shell_channel = shell_channel self.session_id = session_id self.key = key + self.task_group = task_group self.msg_cnt = 0 def send(self, buffers): @@ -290,6 +328,6 @@ def send(self, buffers): buffers=buffers, ) self.msg_cnt += 1 - asyncio.create_task( - send_message(msg, self.shell_channel, self.key, change_date_to_str=True) + self.task_group.start_soon( + lambda: send_message(msg, self.shell_channel, self.key, change_date_to_str=True) ) diff --git a/plugins/kernels/fps_kernels/kernel_driver/message.py b/plugins/kernels/fps_kernels/kernel_driver/message.py index 6946c73c..5905a46c 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/message.py +++ b/plugins/kernels/fps_kernels/kernel_driver/message.py @@ -32,7 +32,7 @@ def date_to_str(obj: Dict[str, Any]): def utcnow() -> datetime: - return datetime.utcnow().replace(tzinfo=timezone.utc) + return datetime.now(tz=timezone.utc) def create_message_header(msg_type: str, session_id: str, msg_id: str) -> Dict[str, Any]: diff --git a/plugins/kernels/fps_kernels/kernel_server/server.py b/plugins/kernels/fps_kernels/kernel_server/server.py index e10a2b84..0336136c 100644 --- a/plugins/kernels/fps_kernels/kernel_server/server.py +++ b/plugins/kernels/fps_kernels/kernel_server/server.py @@ -1,11 +1,12 @@ -import asyncio import json -import os import signal import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Dict, Iterable, List, Optional, cast +import anyio +from anyio import TASK_STATUS_IGNORED, Event, create_task_group +from anyio.abc import TaskGroup, TaskStatus from fastapi import WebSocket, WebSocketDisconnect from starlette.websockets import WebSocketState @@ -62,7 +63,6 @@ def __init__( self.connection_cfg = connection_cfg self.connection_file = connection_file self.write_connection_file = write_connection_file - self.channel_tasks: List[asyncio.Task] = [] self.sessions: Dict[str, AcceptedWebSocket] = {} # blocked messages and allowed messages are mutually exclusive self.blocked_messages: List[str] = [] @@ -104,77 +104,88 @@ def allow_messages(self, message_types: Optional[Iterable[str]] = None): def connections(self) -> int: return len(self.sessions) - async def start(self, launch_kernel: bool = True) -> None: - self.last_activity = { - "date": datetime.utcnow().isoformat() + "Z", - "execution_state": "starting", - } - if launch_kernel: - if not self.kernelspec_path: - raise RuntimeError("Could not find a kernel, maybe you forgot to install one?") - self.kernel_process = await _launch_kernel( - self.kernelspec_path, - self.connection_file_path, - self.kernel_cwd, - self.capture_kernel_output, + async def start( + self, launch_kernel: bool = True, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED + ) -> None: + async with create_task_group() as tg: + self.task_group = tg + self.last_activity = { + "date": datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z"), + "execution_state": "starting", + } + if launch_kernel: + if not self.kernelspec_path: + raise RuntimeError("Could not find a kernel, maybe you forgot to install one?") + self.kernel_process = await _launch_kernel( + self.kernelspec_path, + self.connection_file_path, + self.kernel_cwd, + self.capture_kernel_output, + ) + assert self.connection_cfg is not None + identity = uuid.uuid4().hex.encode("ascii") + self.shell_channel = connect_channel("shell", self.connection_cfg, identity=identity) + self.stdin_channel = connect_channel("stdin", self.connection_cfg, identity=identity) + self.control_channel = connect_channel( + "control", self.connection_cfg, identity=identity ) - assert self.connection_cfg is not None - identity = uuid.uuid4().hex.encode("ascii") - self.shell_channel = connect_channel("shell", self.connection_cfg, identity=identity) - self.stdin_channel = connect_channel("stdin", self.connection_cfg, identity=identity) - self.control_channel = connect_channel("control", self.connection_cfg, identity=identity) - self.iopub_channel = connect_channel("iopub", self.connection_cfg) - await self._wait_for_ready() - self.channel_tasks += [ - asyncio.create_task(self.listen("shell")), - asyncio.create_task(self.listen("stdin")), - asyncio.create_task(self.listen("control")), - asyncio.create_task(self.listen("iopub")), - ] + self.iopub_channel = connect_channel("iopub", self.connection_cfg) + await self._wait_for_ready() + tg.start_soon(lambda: self.listen("shell")) + tg.start_soon(lambda: self.listen("stdin")) + tg.start_soon(lambda: self.listen("control")) + tg.start_soon(lambda: self.listen("iopub")) + task_status.started() async def stop(self) -> None: + try: + self.kernel_process.terminate() + except ProcessLookupError: + pass + await self.kernel_process.wait() + self.task_group.cancel_scope.cancel() if self.write_connection_file: - # FIXME: stop kernel in a better way + path = anyio.Path(self.connection_file_path) try: - self.kernel_process.send_signal(signal.SIGINT) - self.kernel_process.kill() - await self.kernel_process.wait() - except BaseException: + await path.unlink() + except Exception: pass - try: - os.remove(self.connection_file_path) - except BaseException: - pass - for task in self.channel_tasks: - task.cancel() - self.channel_tasks = [] def interrupt(self) -> None: self.kernel_process.send_signal(signal.SIGINT) - async def restart(self) -> None: + async def restart(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: await self.stop() self.setup_connection_file() - await self.start() + await self.start(task_status=task_status) async def serve( self, websocket: AcceptedWebSocket, session_id: str, permissions: Optional[Dict[str, List[str]]], + stop_event: Event, ): self.sessions[session_id] = websocket self.can_execute = permissions is None or "execute" in permissions.get("kernels", []) - await self.listen_web(websocket) + async with create_task_group() as tg: + tg.start_soon(self.listen_web, websocket, tg) + tg.start_soon(self._watch_stop, tg, stop_event) + # the session could have been removed through the REST API, so check if it still exists if session_id in self.sessions: del self.sessions[session_id] - async def listen_web(self, websocket: AcceptedWebSocket): + async def _watch_stop(self, tg: TaskGroup, stop_event: Event): + await stop_event.wait() + tg.cancel_scope.cancel() + + async def listen_web(self, websocket: AcceptedWebSocket, tg: TaskGroup): try: await self.send_to_zmq(websocket) except WebSocketDisconnect: pass + tg.cancel_scope.cancel() async def listen(self, channel_name: str): if channel_name == "shell": @@ -264,7 +275,7 @@ async def send_to_ws(self, websocket, parts, parent_header, channel_name): "execution_state": msg["content"]["execution_state"], } elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org": - bin_msg = serialize_msg_to_ws_v1(parts, channel_name) + bin_msg = b"".join(serialize_msg_to_ws_v1(parts, channel_name)) try: await websocket.websocket.send_bytes(bin_msg) except BaseException: diff --git a/plugins/kernels/fps_kernels/main.py b/plugins/kernels/fps_kernels/main.py index 3cb4abc2..fb5a3a41 100644 --- a/plugins/kernels/fps_kernels/main.py +++ b/plugins/kernels/fps_kernels/main.py @@ -1,19 +1,14 @@ from __future__ import annotations -import asyncio -from collections.abc import AsyncGenerator -from pathlib import Path -from typing import Optional - -from asphalt.core import Component, Context, context_teardown +from asphalt.core import Component, add_resource, get_resource, start_service_task from jupyverse_api.app import App from jupyverse_api.auth import Auth from jupyverse_api.frontend import FrontendConfig from jupyverse_api.kernels import Kernels, KernelsConfig +from jupyverse_api.main import Lifespan from jupyverse_api.yjs import Yjs -from .kernel_driver.paths import jupyter_runtime_dir from .routes import _Kernels @@ -21,36 +16,18 @@ class KernelsComponent(Component): def __init__(self, **kwargs): self.kernels_config = KernelsConfig(**kwargs) - @context_teardown - async def start( - self, - ctx: Context, - ) -> AsyncGenerator[None, Optional[BaseException]]: - ctx.add_resource(self.kernels_config, types=KernelsConfig) - - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - frontend_config = await ctx.request_resource(FrontendConfig) - yjs = ( - await ctx.request_resource(Yjs) # type: ignore - if self.kernels_config.require_yjs - else None - ) - - kernels = _Kernels(app, self.kernels_config, auth, frontend_config, yjs) - ctx.add_resource(kernels, types=Kernels) - - if self.kernels_config.allow_external_kernels: - external_connection_dir = self.kernels_config.external_connection_dir - if external_connection_dir is None: - path = Path(jupyter_runtime_dir()) / "external_kernels" - else: - path = Path(external_connection_dir) - task = asyncio.create_task(kernels.watch_connection_files(path)) - - yield - - if self.kernels_config.allow_external_kernels: - task.cancel() - for kernel in kernels.kernels.values(): - await kernel["server"].stop() + async def start(self) -> None: + add_resource(self.kernels_config, types=KernelsConfig) + + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + frontend_config = await get_resource(FrontendConfig, wait=True) + lifespan = await get_resource(Lifespan, wait=True) + if self.kernels_config.require_yjs: + yjs = await get_resource(Yjs, wait=True) # type: ignore[type-abstract] + else: + yjs = None + + kernels = _Kernels(app, self.kernels_config, auth, frontend_config, yjs, lifespan) # type: ignore[type-abstract] + await start_service_task(kernels.start, "Kernels", teardown_action=kernels.stop) + add_resource(kernels, types=Kernels) diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py index e15eab80..77e475d5 100644 --- a/plugins/kernels/fps_kernels/routes.py +++ b/plugins/kernels/fps_kernels/routes.py @@ -1,10 +1,13 @@ import json import logging import uuid +from functools import partial from http import HTTPStatus from pathlib import Path from typing import Dict, List, Optional, Set, Tuple +from anyio import TASK_STATUS_IGNORED, Event, create_task_group +from anyio.abc import TaskStatus from fastapi import HTTPException, Response from fastapi.responses import FileResponse from starlette.requests import Request @@ -15,10 +18,12 @@ from jupyverse_api.frontend import FrontendConfig from jupyverse_api.kernels import Kernels, KernelsConfig from jupyverse_api.kernels.models import CreateSession, Execution, Kernel, Notebook, Session +from jupyverse_api.main import Lifespan from jupyverse_api.yjs import Yjs from .kernel_driver.driver import KernelDriver from .kernel_driver.kernelspec import find_kernelspec, kernelspec_dirs +from .kernel_driver.paths import jupyter_runtime_dir from .kernel_server.server import ( AcceptedWebSocket, KernelServer, @@ -36,17 +41,50 @@ def __init__( auth: Auth, frontend_config: FrontendConfig, yjs: Optional[Yjs], + lifespan: Lifespan, ) -> None: super().__init__(app=app, auth=auth) self.kernels_config = kernels_config self.frontend_config = frontend_config self.yjs = yjs + self.lifespan = lifespan self.kernelspecs: dict = {} self.kernel_id_to_connection_file: Dict[str, str] = {} self.sessions: Dict[str, Session] = {} self.kernels = kernels self._app = app + self.stop_event = Event() + + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + async with create_task_group() as tg: + self.task_group = tg + if self.kernels_config.allow_external_kernels: + external_connection_dir = self.kernels_config.external_connection_dir + if external_connection_dir is None: + path = Path(jupyter_runtime_dir()) / "external_kernels" + else: + path = Path(external_connection_dir) + tg.start_soon(lambda: self.watch_connection_files(path)) + tg.start_soon(self.on_shutdown) + task_status.started() + await self.stop_event.wait() + + async def stop(self) -> None: + if self.stop_event.is_set(): + return + + async with create_task_group(): + for kernel in self.kernels.values(): + self.task_group.start_soon(kernel["server"].stop) + if kernel["driver"] is not None: + self.task_group.start_soon(kernel["driver"].stop) + self.stop_event.set() + self.task_group.cancel_scope.cancel() + + async def on_shutdown(self): + await self.lifespan.shutdown_request.wait() + await self.stop() async def get_status( self, @@ -179,7 +217,7 @@ async def create_session( ) kernel_id = str(uuid.uuid4()) kernels[kernel_id] = {"name": kernel_name, "server": kernel_server, "driver": None} - await kernel_server.start() + await self.task_group.start(kernel_server.start) elif kernel_id is not None: # external kernel kernel_name = kernels[kernel_id]["name"] @@ -188,7 +226,7 @@ async def create_session( write_connection_file=False, ) kernels[kernel_id]["server"] = kernel_server - await kernel_server.start(launch_kernel=False) + await self.task_group.start(partial(kernel_server.start, launch_kernel=False)) else: return session_id = str(uuid.uuid4()) @@ -236,7 +274,7 @@ async def restart_kernel( ): if kernel_id in kernels: kernel = kernels[kernel_id] - await kernel["server"].restart() + await self.task_group.start(kernel["server"].restart) result = { "id": kernel_id, "name": kernel["name"], @@ -274,7 +312,7 @@ async def execute_cell( connection_file=kernel["server"].connection_file_path, yjs=self.yjs, ) - await driver.connect() + await self.task_group.start(driver.start) driver = kernel["driver"] await driver.execute(ycell, wait_for_executed=False) @@ -332,16 +370,18 @@ async def kernel_channels( connection_file=self.kernel_id_to_connection_file[kernel_id], write_connection_file=False, ) - await kernel_server.start(launch_kernel=False) + await self.task_group.start(partial(kernel_server.start, launch_kernel=False)) kernels[kernel_id]["server"] = kernel_server - await kernel_server.serve(accepted_websocket, session_id, permissions) + await kernel_server.serve( + accepted_websocket, session_id, permissions, self.lifespan.shutdown_request + ) async def watch_connection_files(self, path: Path) -> None: # first time scan, treat everything as added files initial_changes = {(Change.added, str(p)) for p in path.iterdir()} await self.process_connection_files(initial_changes) # then, on every change - async for changes in awatch(path): + async for changes in awatch(path, stop_event=self.stop_event): await self.process_connection_files(changes) async def process_connection_files(self, changes: Set[Tuple[Change, str]]): diff --git a/plugins/kernels/pyproject.toml b/plugins/kernels/pyproject.toml index 9a11f90e..4340804b 100644 --- a/plugins/kernels/pyproject.toml +++ b/plugins/kernels/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "types-python-dateutil", "watchfiles >=0.16.1,<1", "jupyverse-api >=0.1.2,<1", + "anyio", ] dynamic = [ "version",] [[project.authors]] diff --git a/plugins/lab/fps_lab/main.py b/plugins/lab/fps_lab/main.py index 50292912..2628bab2 100644 --- a/plugins/lab/fps_lab/main.py +++ b/plugins/lab/fps_lab/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -10,14 +10,11 @@ class LabComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - frontend_config = await ctx.request_resource(FrontendConfig) - jupyterlab_config = ctx.get_resource(JupyterLabConfig) + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + frontend_config = await get_resource(FrontendConfig, wait=True) + jupyterlab_config = await get_resource(JupyterLabConfig, optional=True) lab = _Lab(app, auth, frontend_config, jupyterlab_config) - ctx.add_resource(lab, types=Lab) + add_resource(lab, types=Lab) diff --git a/plugins/lab/fps_lab/routes.py b/plugins/lab/fps_lab/routes.py index 3f68f844..11667a35 100644 --- a/plugins/lab/fps_lab/routes.py +++ b/plugins/lab/fps_lab/routes.py @@ -7,6 +7,11 @@ from pathlib import Path from typing import List, Optional, Tuple +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + import json5 # type: ignore from babel import Locale from fastapi import Response, status diff --git a/plugins/login/fps_login/main.py b/plugins/login/fps_login/main.py index f513ca57..95550e46 100644 --- a/plugins/login/fps_login/main.py +++ b/plugins/login/fps_login/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import AuthConfig @@ -8,12 +8,9 @@ class LoginComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth_config = await ctx.request_resource(AuthConfig) + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth_config = await get_resource(AuthConfig, wait=True) login = _Login(app, auth_config) - ctx.add_resource(login, types=Login) + add_resource(login, types=Login) diff --git a/plugins/nbconvert/fps_nbconvert/main.py b/plugins/nbconvert/fps_nbconvert/main.py index c865adf8..5b48d53c 100644 --- a/plugins/nbconvert/fps_nbconvert/main.py +++ b/plugins/nbconvert/fps_nbconvert/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -8,12 +8,9 @@ class NbconvertComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] nbconvert = _Nbconvert(app, auth) - ctx.add_resource(nbconvert, types=Nbconvert) + add_resource(nbconvert, types=Nbconvert) diff --git a/plugins/noauth/fps_noauth/main.py b/plugins/noauth/fps_noauth/main.py index 911a79f8..92bf0846 100644 --- a/plugins/noauth/fps_noauth/main.py +++ b/plugins/noauth/fps_noauth/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource from jupyverse_api.auth import Auth @@ -6,9 +6,6 @@ class NoAuthComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: + async def start(self) -> None: no_auth = _NoAuth() - ctx.add_resource(no_auth, types=Auth) + add_resource(no_auth, types=Auth) diff --git a/plugins/notebook/fps_notebook/main.py b/plugins/notebook/fps_notebook/main.py index 31521a06..a4b9e544 100644 --- a/plugins/notebook/fps_notebook/main.py +++ b/plugins/notebook/fps_notebook/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -10,14 +10,11 @@ class NotebookComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - frontend_config = await ctx.request_resource(FrontendConfig) - lab = await ctx.request_resource(Lab) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + frontend_config = await get_resource(FrontendConfig, wait=True) + lab = await get_resource(Lab, wait=True) # type: ignore[type-abstract] notebook = _Notebook(app, auth, frontend_config, lab) - ctx.add_resource(notebook, types=Notebook) + add_resource(notebook, types=Notebook) diff --git a/plugins/resource_usage/fps_resource_usage/main.py b/plugins/resource_usage/fps_resource_usage/main.py index 4cc3c2b7..14b64669 100644 --- a/plugins/resource_usage/fps_resource_usage/main.py +++ b/plugins/resource_usage/fps_resource_usage/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -11,12 +11,9 @@ class ResourceUsageComponent(Component): def __init__(self, **kwargs): self.resource_usage_config = ResourceUsageConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] resource_usage = _ResourceUsage(app, auth, self.resource_usage_config) - ctx.add_resource(resource_usage, types=ResourceUsage) + add_resource(resource_usage, types=ResourceUsage) diff --git a/plugins/terminals/fps_terminals/main.py b/plugins/terminals/fps_terminals/main.py index 93a00719..e56373ff 100644 --- a/plugins/terminals/fps_terminals/main.py +++ b/plugins/terminals/fps_terminals/main.py @@ -1,7 +1,7 @@ -import os +import sys from typing import Type -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource, start_service_task from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -10,19 +10,17 @@ from .routes import _Terminals _TerminalServer: Type[TerminalServer] -if os.name == "nt": +if sys.platform == "win32": from .win_server import _TerminalServer else: from .server import _TerminalServer class TerminalsComponent(Component): - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] terminals = _Terminals(app, auth, _TerminalServer) - ctx.add_resource(terminals, types=Terminals) + await start_service_task(terminals.start, name="Terminals", teardown_action=terminals.stop) + add_resource(terminals, types=Terminals) diff --git a/plugins/terminals/fps_terminals/routes.py b/plugins/terminals/fps_terminals/routes.py index ca7d2f8d..337c07fb 100644 --- a/plugins/terminals/fps_terminals/routes.py +++ b/plugins/terminals/fps_terminals/routes.py @@ -1,7 +1,8 @@ -from datetime import datetime +from datetime import datetime, timezone from http import HTTPStatus from typing import Any, Dict, Type +from anyio import Event, create_task_group from fastapi import Response from jupyverse_api.app import App @@ -15,6 +16,16 @@ class _Terminals(Terminals): def __init__(self, app: App, auth: Auth, _TerminalServer: Type[TerminalServer]) -> None: super().__init__(app=app, auth=auth) self.TerminalServer = _TerminalServer + self.stop_event = Event() + + async def start(self): + await self.stop_event.wait() + + async def stop(self): + async with create_task_group() as tg: + for terminal in TERMINALS.values(): + tg.start_soon(terminal["server"].stop) + self.stop_event.set() async def get_terminals( self, @@ -29,7 +40,7 @@ async def create_terminal( name = str(len(TERMINALS) + 1) terminal = Terminal( name=name, - last_activity=datetime.utcnow().isoformat() + "Z", + last_activity=datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z"), ) server = self.TerminalServer() TERMINALS[name] = {"info": terminal, "server": server} diff --git a/plugins/terminals/fps_terminals/server.py b/plugins/terminals/fps_terminals/server.py index c086f8c6..bbdb55b2 100644 --- a/plugins/terminals/fps_terminals/server.py +++ b/plugins/terminals/fps_terminals/server.py @@ -1,75 +1,140 @@ -import asyncio import fcntl import os import pty +import selectors import shlex import struct import termios +from functools import partial +from anyio import create_memory_object_stream, create_task_group, from_thread, to_thread +from anyio.abc import ByteReceiveStream, ByteSendStream from fastapi import WebSocketDisconnect from jupyverse_api.terminals import TerminalServer -def open_terminal(command="bash", columns=80, lines=24): - pid, fd = pty.fork() - if pid == 0: - argv = shlex.split(command) - env = os.environ.copy() - env.update(TERM="linux", COLUMNS=str(columns), LINES=str(lines)) - os.execvpe(argv[0], argv, env) - return fd - - class _TerminalServer(TerminalServer): def __init__(self): - self.fd = open_terminal() + # FIXME: pass in config + command = "bash" + columns = 80 + lines = 24 + + pid, fd = pty.fork() + if pid == 0: + argv = shlex.split(command) + env = os.environ.copy() + env.update(TERM="linux", COLUMNS=str(columns), LINES=str(lines)) + os.execvpe(argv[0], argv, env) + self.fd = fd self.p_out = os.fdopen(self.fd, "w+b", 0) self.websockets = [] - async def serve(self, websocket, permissions): + async def serve(self, websocket, permissions) -> None: self.websocket = websocket + self.permissions = permissions self.websockets.append(websocket) - self.event = asyncio.Event() - self.loop = asyncio.get_event_loop() - - task = asyncio.create_task(self.send_data()) - - def on_output(): - try: - self.data_or_disconnect = self.p_out.read(65536).decode() - self.event.set() - except Exception: - self.loop.remove_reader(self.p_out) - self.data_or_disconnect = None - self.event.set() - - self.loop.add_reader(self.p_out, on_output) - await websocket.send_json(["setup", {}]) - can_execute = permissions is None or "execute" in permissions.get("terminals", []) + + async with create_task_group() as self.task_group: + self.recv_stream = ReceiveStream(self.p_out, self.task_group) + self.send_stream = SendStream(self.p_out) + self.task_group.start_soon(self.backend_to_frontend) + self.task_group.start_soon(self.frontend_to_backend) + + async def stop(self) -> None: + os.write(self.recv_stream.pipeout, b"0") + self.p_out.close() + try: + self.recv_stream.sel.unregister(self.p_out) + except Exception: + pass + self.task_group.cancel_scope.cancel() + + async def backend_to_frontend(self): + while True: + data = (await self.recv_stream.receive(65536)).decode() + for websocket in self.websockets: + await websocket.send_json(["stdout", data]) + + async def frontend_to_backend(self): + await self.websocket.send_json(["setup", {}]) + can_execute = self.permissions is None or "execute" in self.permissions.get("terminals", []) try: while True: - msg = await websocket.receive_json() + msg = await self.websocket.receive_json() if can_execute: if msg[0] == "stdin": - self.p_out.write(msg[1].encode()) + await self.send_stream.send(msg[1].encode()) elif msg[0] == "set_size": winsize = struct.pack("HH", msg[1], msg[2]) fcntl.ioctl(self.fd, termios.TIOCSWINSZ, winsize) except WebSocketDisconnect: - task.cancel() - - async def send_data(self): - while True: - await self.event.wait() - self.event.clear() - if self.data_or_disconnect is None: - await self.websocket.send_json(["disconnect", 1]) - else: - for websocket in self.websockets: - await websocket.send_json(["stdout", self.data_or_disconnect]) + self.quit(self.websocket) + self.task_group.cancel_scope.cancel() def quit(self, websocket): - self.websockets.remove(websocket) - if not self.websockets: - os.close(self.fd) + try: + os.write(self.recv_stream.pipeout, b"0") + self.p_out.close() + self.recv_stream.sel.unregister(self.p_out) + self.websockets.remove(websocket) + if not self.websockets: + os.close(self.fd) + except Exception: + pass + + +class ReceiveStream(ByteReceiveStream): + def __init__(self, p_out, task_group): + self.p_out = p_out + self.sel = selectors.DefaultSelector() + self.sel.register(self.p_out, selectors.EVENT_READ, self._read) + self.pipein, self.pipeout = os.pipe() + f = os.fdopen(self.pipein, "r+b", 0) + + def cb(): + return True + + self.sel.register(f, selectors.EVENT_READ, cb) + self.send_stream, self.recv_stream = create_memory_object_stream[bytes]( + max_buffer_size=65536 + ) + + def reader(): + while True: + events = self.sel.select() + for key, mask in events: + callback = key.data + if callback(): + return + + task_group.start_soon(partial(to_thread.run_sync, reader, abandon_on_cancel=True)) + + def _read(self) -> bool: + try: + data = self.p_out.read(65536) + except OSError: + self.sel.unregister(self.p_out) + return True + else: + from_thread.run_sync(self.send_stream.send_nowait, data) + return False + + async def receive(self, max_bytes: int = 65536) -> bytes: + data = await self.recv_stream.receive() + return data + + async def aclose(self) -> None: + pass + + +class SendStream(ByteSendStream): + def __init__(self, p_out): + self.p_out = p_out + + async def send(self, item: bytes) -> None: + self.p_out.write(item) + + async def aclose(self) -> None: + pass diff --git a/plugins/terminals/fps_terminals/win_server.py b/plugins/terminals/fps_terminals/win_server.py index f1865391..a71c3474 100644 --- a/plugins/terminals/fps_terminals/win_server.py +++ b/plugins/terminals/fps_terminals/win_server.py @@ -1,8 +1,8 @@ -import asyncio import os from functools import partial -from anyio import to_thread +from anyio import create_task_group, to_thread +from fastapi import WebSocketDisconnect from winpty import PTY # type: ignore from jupyverse_api.terminals import TerminalServer @@ -20,16 +20,20 @@ def __init__(self): self.process = open_terminal() self.websockets = [] - async def serve(self, websocket): + async def serve(self, websocket, permissions) -> None: self.websocket = websocket + self.permissions = permissions self.websockets.append(websocket) await websocket.send_json(["setup", {}]) - self.send_task = asyncio.create_task(self.send_data()) - self.recv_task = asyncio.create_task(self.recv_data()) + async with create_task_group() as tg: + self.task_group = tg + tg.start_soon(self.send_data) + tg.start_soon(self.recv_data) - await asyncio.gather(self.send_task, self.recv_task) + async def stop(self) -> None: + self.task_group.cancel_scope.cancel() async def send_data(self): while True: @@ -43,19 +47,21 @@ async def send_data(self): await websocket.send_json(["stdout", data]) async def recv_data(self): - while True: - try: + can_execute = self.permissions is None or "execute" in self.permissions.get("terminals", []) + try: + while True: msg = await self.websocket.receive_json() - except Exception: - return - if msg[0] == "stdin": - self.process.write(msg[1]) - elif msg[0] == "set_size": - self.process.set_size(msg[2], msg[1]) + if can_execute: + if msg[0] == "stdin": + self.process.write(msg[1]) + elif msg[0] == "set_size": + self.process.set_size(msg[2], msg[1]) + except WebSocketDisconnect: + self.quit(self.websocket) + self.task_group.cancel_scope.cancel() def quit(self, websocket): self.websockets.remove(websocket) if not self.websockets: - self.send_task.cancel() - self.recv_task.cancel() + self.task_group.cancel_scope.cancel() del self.process diff --git a/plugins/webdav/fps_webdav/main.py b/plugins/webdav/fps_webdav/main.py index 2e6c4662..7da847bd 100644 --- a/plugins/webdav/fps_webdav/main.py +++ b/plugins/webdav/fps_webdav/main.py @@ -1,4 +1,4 @@ -from asphalt.core import Component, Context +from asphalt.core import Component, add_resource, get_resource from jupyverse_api.app import App @@ -10,11 +10,8 @@ class WebDAVComponent(Component): def __init__(self, **kwargs): self.webdav_config = WebDAVConfig(**kwargs) - async def start( - self, - ctx: Context, - ) -> None: - app = await ctx.request_resource(App) + async def start(self) -> None: + app = await get_resource(App, wait=True) webdav = WebDAV(app, self.webdav_config) - ctx.add_resource(webdav) + add_resource(webdav) diff --git a/plugins/webdav/fps_webdav/routes.py b/plugins/webdav/fps_webdav/routes.py index b31b1fc9..bf67903a 100644 --- a/plugins/webdav/fps_webdav/routes.py +++ b/plugins/webdav/fps_webdav/routes.py @@ -46,7 +46,7 @@ def __init__(self, app: App, webdav_config: WebDAVConfig): for account in webdav_config.account_mapping: logger.info(f"WebDAV user {account.username} has password {account.password}") - webdav_conf = webdav_config.dict() + webdav_conf = webdav_config.model_dump() init_config_from_obj(webdav_conf) webdav_aep = AppEntryParameters() webdav_app = get_asgi_app(aep=webdav_aep, config_obj=webdav_conf) diff --git a/plugins/webdav/pyproject.toml b/plugins/webdav/pyproject.toml index 718adeee..62f23725 100644 --- a/plugins/webdav/pyproject.toml +++ b/plugins/webdav/pyproject.toml @@ -30,7 +30,6 @@ Homepage = "https://jupyter.org" test = [ "easywebdav", "pytest", - "pytest-asyncio", ] [tool.check-manifest] diff --git a/plugins/webdav/tests/conftest.py b/plugins/webdav/tests/conftest.py new file mode 100644 index 00000000..af7e4799 --- /dev/null +++ b/plugins/webdav/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture +def anyio_backend(): + return "asyncio" diff --git a/plugins/webdav/tests/test_webdav.py b/plugins/webdav/tests/test_webdav.py index dc6ed856..a4e53df4 100644 --- a/plugins/webdav/tests/test_webdav.py +++ b/plugins/webdav/tests/test_webdav.py @@ -24,17 +24,17 @@ def configure(components, config): return _components -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python >=3.10") async def test_webdav(unused_tcp_port): components = configure( COMPONENTS, {"webdav": {"account_mapping": [{"username": "foo", "password": "bar"}]}} ) - async with Context() as ctx: + async with Context(): await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() webdav = easywebdav.connect( "127.0.0.1", port=unused_tcp_port, path="webdav", username="foo", password="bar" diff --git a/plugins/yjs/fps_yjs/main.py b/plugins/yjs/fps_yjs/main.py index eacd1b91..b4cdbcc8 100644 --- a/plugins/yjs/fps_yjs/main.py +++ b/plugins/yjs/fps_yjs/main.py @@ -1,36 +1,29 @@ from __future__ import annotations -from collections.abc import AsyncGenerator -from typing import Optional - -from asphalt.core import Component, Context, context_teardown +from asphalt.core import Component, add_resource, get_resource, start_service_task from jupyverse_api.app import App from jupyverse_api.auth import Auth from jupyverse_api.contents import Contents +from jupyverse_api.main import Lifespan from jupyverse_api.yjs import Yjs from .routes import _Yjs class YjsComponent(Component): - @context_teardown - async def start( - self, - ctx: Context, - ) -> AsyncGenerator[None, Optional[BaseException]]: - app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) # type: ignore - contents = await ctx.request_resource(Contents) # type: ignore - - yjs = _Yjs(app, auth, contents) - ctx.add_resource(yjs, types=Yjs) - - # start indexing in the background - contents.file_id_manager - - yield - - yjs.room_manager.stop() - contents.file_id_manager.stop_watching_files.set() - await contents.file_id_manager.stopped_watching_files.wait() + async def start(self) -> None: + app = await get_resource(App, wait=True) + auth = await get_resource(Auth, wait=True) # type: ignore[type-abstract] + contents = await get_resource(Contents, wait=True) # type: ignore[type-abstract] + lifespan = await get_resource(Lifespan, wait=True) + + yjs = _Yjs(app, auth, contents, lifespan) + add_resource(yjs, types=Yjs) + + await start_service_task(yjs.start, "Room manager", teardown_action=yjs.stop) + await start_service_task( + contents.file_id_manager.start, + "File ID manager", + teardown_action=contents.file_id_manager.stop, + ) diff --git a/plugins/yjs/fps_yjs/routes.py b/plugins/yjs/fps_yjs/routes.py index 1a023d95..14460686 100644 --- a/plugins/yjs/fps_yjs/routes.py +++ b/plugins/yjs/fps_yjs/routes.py @@ -1,12 +1,13 @@ from __future__ import annotations -import asyncio import logging from datetime import datetime from functools import partial from typing import Dict from uuid import uuid4 +from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group, sleep +from anyio.abc import TaskGroup, TaskStatus from fastapi import ( HTTPException, Request, @@ -20,6 +21,7 @@ from jupyverse_api.app import App from jupyverse_api.auth import Auth, User from jupyverse_api.contents import Contents +from jupyverse_api.main import Lifespan from jupyverse_api.yjs import Yjs from jupyverse_api.yjs.models import CreateDocumentSession @@ -46,15 +48,32 @@ def __init__( app: App, auth: Auth, contents: Contents, + lifespan: Lifespan, ) -> None: super().__init__(app=app, auth=auth) self.contents = contents - self.room_manager = RoomManager(contents) + self.lifespan = lifespan if Widgets is None: self.widgets = None else: self.widgets = Widgets() # type: ignore + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + async with create_task_group() as tg: + self._task_group = tg + self.room_manager = RoomManager(self.contents, self.lifespan, tg) + await tg.start(self.room_manager.websocket_server.start) + tg.start_soon(self.room_manager.on_shutdown) + task_status.started() + + async def stop(self) -> None: + for watcher in self.room_manager.watchers.values(): + watcher.cancel() + for saver in self.room_manager.savers.values(): + saver.cancel() + for cleaner in self.room_manager.cleaners.values(): + cleaner.cancel() + async def collaboration_room_websocket( self, path, @@ -140,33 +159,31 @@ async def recv(self): class RoomManager: contents: Contents + lifespan: Lifespan documents: Dict[str, YBaseDoc] - watchers: Dict[str, asyncio.Task] - savers: Dict[str, asyncio.Task] - cleaners: Dict[YRoom, asyncio.Task] + watchers: Dict[str, Task] + savers: Dict[str, Task] + cleaners: Dict[YRoom, Task] last_modified: Dict[str, datetime] websocket_server: JupyterWebsocketServer - lock: asyncio.Lock + lock: Lock + _task_group: TaskGroup - def __init__(self, contents: Contents): + def __init__(self, contents: Contents, lifespan: Lifespan, task_group: TaskGroup): self.contents = contents + self.lifespan = lifespan + self._task_group = task_group self.documents = {} # a dictionary of room_name:document self.watchers = {} # a dictionary of file_id:task self.savers = {} # a dictionary of file_id:task self.cleaners = {} # a dictionary of room:task self.last_modified = {} # a dictionary of file_id:last_modification_date self.websocket_server = JupyterWebsocketServer(rooms_ready=False, auto_clean_rooms=False) - self.websocket_server_task = asyncio.create_task(self.websocket_server.start()) - self.lock = asyncio.Lock() + self.lock = Lock() - def stop(self): - for watcher in self.watchers.values(): - watcher.cancel() - for saver in self.savers.values(): - saver.cancel() - for cleaner in self.cleaners.values(): - cleaner.cancel() - self.websocket_server.stop() + async def on_shutdown(self): + await self.lifespan.shutdown_request.wait() + await self.websocket_server.stop() async def serve(self, websocket: YWebsocket, permissions) -> None: room = await self.websocket_server.get_room(websocket.path) @@ -216,16 +233,17 @@ async def serve(self, websocket: YWebsocket, permissions) -> None: ) # update the document when file changes if file_id not in self.watchers: - self.watchers[file_id] = asyncio.create_task( - self.watch_file(file_format, file_id, document) + self.watchers[file_id] = Task( + self.watch_file(file_format, file_id, document), self._task_group ) - await self.websocket_server.started.wait() - await self.websocket_server.serve(websocket) + await self.websocket_server.serve(websocket, self.lifespan.shutdown_request) if is_stored_document and not room.clients: # no client in this room after we disconnect - self.cleaners[room] = asyncio.create_task(self.maybe_clean_room(room, websocket.path)) + self.cleaners[room] = Task( + self.maybe_clean_room(room, websocket.path), self._task_group + ) async def filter_message(self, can_write: bool, message: bytes) -> bool: """ @@ -262,18 +280,18 @@ async def watch_file(self, file_format: str, file_id: str, document: YBaseDoc) - file_path = await self.get_file_path(file_id, document) assert file_path is not None logger.debug(f"Watching file: {file_path}") - while True: - watcher = self.contents.file_id_manager.watch(file_path) - async for changes in watcher: - new_file_path = await self.get_file_path(file_id, document) - if new_file_path is None: - continue - if new_file_path != file_path: - # file was renamed - self.contents.file_id_manager.unwatch(file_path, watcher) - file_path = new_file_path - # break - await self.maybe_load_file(file_format, file_path, file_id) + # FIXME: handle file rename/move? + watcher = self.contents.file_id_manager.watch(file_path) + async for changes in watcher: + new_file_path = await self.get_file_path(file_id, document) + if new_file_path is None: + continue + if new_file_path != file_path: + # file was renamed + self.contents.file_id_manager.unwatch(file_path, watcher) + file_path = new_file_path + # break + await self.maybe_load_file(file_format, file_path, file_id) async def maybe_load_file(self, file_format: str, file_path: str, file_id: str) -> None: async with self.lock: @@ -306,15 +324,15 @@ def on_document_change( ) if file_id in self.savers: self.savers[file_id].cancel() - self.savers[file_id] = asyncio.create_task( - self.maybe_save_document(file_id, file_type, file_format, document) + self.savers[file_id] = Task( + self.maybe_save_document(file_id, file_type, file_format, document), self._task_group ) async def maybe_save_document( self, file_id: str, file_type: str, file_format: str, document: YBaseDoc ) -> None: # save after 1 second of inactivity to prevent too frequent saving - await asyncio.sleep(1) # FIXME: pass in config + await sleep(1) # FIXME: pass in config # if the room cannot be found, don't save try: file_path = await self.get_file_path(file_id, document) @@ -351,7 +369,7 @@ async def maybe_save_document( async def maybe_clean_room(self, room, ws_path: str) -> None: file_id = ws_path.split(":", 2)[2] # keep the document for a while in case someone reconnects - await asyncio.sleep(60) # FIXME: pass in config + await sleep(60) # FIXME: pass in config document = self.documents[ws_path] document.unobserve() del self.documents[ws_path] @@ -380,3 +398,31 @@ async def get_room(self, ws_path: str, ydoc: Doc | None = None) -> YRoom: room = self.rooms[ws_path] await self.start_room(room) return room + + +class Task: + def __init__(self, coro, task_group: TaskGroup, cancel_event: Event | None = None): + self._coro = coro + self._cancel_event = cancel_event + self.cancelled = Event() + self.finished = Event() + task_group.start_soon(self.run) + + def cancel(self): + self.cancelled.set() + + async def run(self): + async with create_task_group() as tg: + tg.start_soon(self._run, tg) + tg.start_soon(self._check_cancellation, self.cancelled, tg) + if self._cancel_event is not None: + tg.start_soon(self._check_cancellation, self._cancel_event, tg) + self.finished.set() + + async def _run(self, tg: TaskGroup): + await self._coro + tg.cancel_scope.cancel() + + async def _check_cancellation(self, cancel_event, tg: TaskGroup): + await cancel_event.wait() + tg.cancel_scope.cancel() diff --git a/plugins/yjs/fps_yjs/ydocs/ybasedoc.py b/plugins/yjs/fps_yjs/ydocs/ybasedoc.py index 7ea34ed2..ef7c1087 100644 --- a/plugins/yjs/fps_yjs/ydocs/ybasedoc.py +++ b/plugins/yjs/fps_yjs/ydocs/ybasedoc.py @@ -16,8 +16,7 @@ def __init__(self, ydoc: Optional[Doc] = None): @property @abstractmethod - def version(self) -> str: - ... + def version(self) -> str: ... @property def ystate(self) -> Map: @@ -60,16 +59,13 @@ def file_id(self, value: str) -> None: self._ystate["file_id"] = value @abstractmethod - def get(self) -> Any: - ... + def get(self) -> Any: ... @abstractmethod - def set(self, value: Any) -> None: - ... + def set(self, value: Any) -> None: ... @abstractmethod - def observe(self, callback: Callable[[str, Any], None]) -> None: - ... + def observe(self, callback: Callable[[str, Any], None]) -> None: ... def unobserve(self) -> None: for k, v in self._subscriptions.items(): diff --git a/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py b/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py index 1e7fb5a2..755e09e0 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py +++ b/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py @@ -25,8 +25,6 @@ class WebsocketProvider: - """WebSocket provider.""" - _ydoc: Doc _update_send_stream: MemoryObjectSendStream _update_receive_stream: MemoryObjectReceiveStream @@ -35,26 +33,6 @@ class WebsocketProvider: _task_group: TaskGroup | None def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) -> None: - """Initialize the object. - - The WebsocketProvider instance should preferably be used as an async context manager: - ```py - async with websocket_provider: - ... - ``` - However, a lower-level API can also be used: - ```py - task = asyncio.create_task(websocket_provider.start()) - await websocket_provider.started.wait() - ... - websocket_provider.stop() - ``` - - Arguments: - ydoc: The YDoc to connect through the WebSocket. - websocket: The WebSocket through which to connect the YDoc. - log: An optional logger. - """ self._ydoc = ydoc self._websocket = websocket self.log = log or getLogger(__name__) @@ -68,7 +46,6 @@ def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) - @property def started(self) -> Event: - """An async event that is set when the WebSocket provider has started.""" if self._started is None: self._started = Event() return self._started @@ -111,11 +88,6 @@ async def _send(self): pass async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): - """Start the WebSocket provider. - - Arguments: - task_status: The status to set when the task has started. - """ if self._starting: return else: @@ -131,7 +103,6 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): task_status.started() def stop(self): - """Stop the WebSocket provider.""" if self._task_group is None: raise RuntimeError("WebsocketProvider not running") diff --git a/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py b/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py index 40100211..27b69fb5 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py +++ b/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py @@ -1,6 +1,5 @@ from __future__ import annotations -from contextlib import AsyncExitStack from logging import Logger, getLogger from anyio import TASK_STATUS_IGNORED, Event, create_task_group @@ -12,61 +11,19 @@ class WebsocketServer: - """WebSocket server.""" - auto_clean_rooms: bool rooms: dict[str, YRoom] - _started: Event | None - _starting: bool - _task_group: TaskGroup | None + _task_group: TaskGroup def __init__( self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None ) -> None: - """Initialize the object. - - The WebsocketServer instance should preferably be used as an async context manager: - ```py - async with websocket_server: - ... - ``` - However, a lower-level API can also be used: - ```py - task = asyncio.create_task(websocket_server.start()) - await websocket_server.started.wait() - ... - websocket_server.stop() - ``` - - Arguments: - rooms_ready: Whether rooms are ready to be synchronized when opened. - auto_clean_rooms: Whether rooms should be deleted when no client is there anymore. - log: An optional logger. - """ self.rooms_ready = rooms_ready self.auto_clean_rooms = auto_clean_rooms self.log = log or getLogger(__name__) self.rooms = {} - self._started = None - self._starting = False - self._task_group = None - - @property - def started(self) -> Event: - """An async event that is set when the WebSocket server has started.""" - if self._started is None: - self._started = Event() - return self._started async def get_room(self, name: str, ydoc: Doc | None = None) -> YRoom: - """Get or create a room with the given name, and start it. - - Arguments: - name: The room name. - - Returns: - The room with the given name, or a new one if no room with that name was found. - """ if name not in self.rooms.keys(): self.rooms[name] = YRoom(ydoc=ydoc, ready=self.rooms_ready, log=self.log) room = self.rooms[name] @@ -74,41 +31,15 @@ async def get_room(self, name: str, ydoc: Doc | None = None) -> YRoom: return room async def start_room(self, room: YRoom) -> None: - """Start a room, if not already started. - - Arguments: - room: The room to start. - """ - if self._task_group is None: - raise RuntimeError( - "The WebsocketServer is not running: use `async with websocket_server:` " - "or `await websocket_server.start()`" - ) - if not room.started.is_set(): await self._task_group.start(room.start) def get_room_name(self, room: YRoom) -> str: - """Get the name of a room. - - Arguments: - room: The room to get the name from. - - Returns: - The room name. - """ return list(self.rooms.keys())[list(self.rooms.values()).index(room)] def rename_room( self, to_name: str, *, from_name: str | None = None, from_room: YRoom | None = None ) -> None: - """Rename a room. - - Arguments: - to_name: The new name of the room. - from_name: The previous name of the room (if `from_room` is not passed). - from_room: The room to be renamed (if `from_name` is not passed). - """ if from_name is not None and from_room is not None: raise RuntimeError("Cannot pass from_name and from_room") if from_name is None: @@ -117,12 +48,6 @@ def rename_room( self.rooms[to_name] = self.rooms.pop(from_name) def delete_room(self, *, name: str | None = None, room: YRoom | None = None) -> None: - """Delete a room. - - Arguments: - name: The name of the room to delete (if `room` is not passed). - room: The room to delete ( if `name` is not passed). - """ if name is not None and room is not None: raise RuntimeError("Cannot pass name and room") if name is None: @@ -131,20 +56,15 @@ def delete_room(self, *, name: str | None = None, room: YRoom | None = None) -> room = self.rooms.pop(name) room.stop() - async def serve(self, websocket: Websocket) -> None: - """Serve a client through a WebSocket. - - Arguments: - websocket: The WebSocket through which to serve the client. - """ - if self._task_group is None: - raise RuntimeError( - "The WebsocketServer is not running: use `async with websocket_server:` " - "or `await websocket_server.start()`" - ) - + async def serve(self, websocket: Websocket, stop_event: Event | None = None) -> None: async with create_task_group() as tg: tg.start_soon(self._serve, websocket, tg) + if stop_event is not None: + tg.start_soon(self._watch_stop, tg, stop_event) + + async def _watch_stop(self, tg: TaskGroup, stop_event: Event): + await stop_event.wait() + tg.cancel_scope.cancel() async def _serve(self, websocket: Websocket, tg: TaskGroup): room = await self.get_room(websocket.path) @@ -155,51 +75,12 @@ async def _serve(self, websocket: Websocket, tg: TaskGroup): self.delete_room(room=room) tg.cancel_scope.cancel() - async def __aenter__(self) -> WebsocketServer: - if self._task_group is not None: - raise RuntimeError("WebsocketServer already running") - - async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) - self._exit_stack = exit_stack.pop_all() - self.started.set() - - return self - - async def __aexit__(self, exc_type, exc_value, exc_tb): - if self._task_group is None: - raise RuntimeError("WebsocketServer not running") - - self._task_group.cancel_scope.cancel() - self._task_group = None - return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) - async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): - """Start the WebSocket server. - - Arguments: - task_status: The status to set when the task has started. - """ - if self._starting: - return - else: - self._starting = True - - if self._task_group is not None: - raise RuntimeError("WebsocketServer already running") - # create the task group and wait forever - async with create_task_group() as self._task_group: - self._task_group.start_soon(Event().wait) - self.started.set() - self._starting = False + async with create_task_group() as tg: + self._task_group = tg + tg.start_soon(Event().wait) task_status.started() - def stop(self) -> None: - """Stop the WebSocket server.""" - if self._task_group is None: - raise RuntimeError("WebsocketServer not running") - + async def stop(self) -> None: self._task_group.cancel_scope.cancel() - self._task_group = None diff --git a/plugins/yjs/fps_yjs/ywebsocket/yroom.py b/plugins/yjs/fps_yjs/ywebsocket/yroom.py index 15fd41de..6a2bb01b 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/yroom.py +++ b/plugins/yjs/fps_yjs/ywebsocket/yroom.py @@ -191,11 +191,6 @@ def stop(self): self._task_group = None async def serve(self, websocket: Websocket): - """Serve a client. - - Arguments: - websocket: The WebSocket through which to serve the client. - """ async with create_task_group() as tg: self.clients.append(websocket) await sync(self.ydoc, websocket, self.log) diff --git a/plugins/yjs/fps_yjs/ywebsocket/ystore.py b/plugins/yjs/fps_yjs/ywebsocket/ystore.py index 127a542e..1615b91b 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/ystore.py +++ b/plugins/yjs/fps_yjs/ywebsocket/ystore.py @@ -10,11 +10,11 @@ from pathlib import Path from typing import AsyncIterator, Awaitable, Callable, cast -import aiosqlite import anyio from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group from anyio.abc import TaskGroup, TaskStatus from pycrdt import Doc +from sqlite_anyio import connect from .yutils import Decoder, get_new_path, write_var_uint @@ -33,16 +33,13 @@ class BaseYStore(ABC): @abstractmethod def __init__( self, path: str, metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None - ): - ... + ): ... @abstractmethod - async def write(self, data: bytes) -> None: - ... + async def write(self, data: bytes) -> None: ... @abstractmethod - async def read(self) -> AsyncIterator[tuple[bytes, bytes]]: - ... + async def read(self) -> AsyncIterator[tuple[bytes, bytes]]: ... @property def started(self) -> Event: @@ -58,16 +55,12 @@ async def __aenter__(self) -> BaseYStore: tg = create_task_group() self._task_group = await exit_stack.enter_async_context(tg) self._exit_stack = exit_stack.pop_all() - tg.start_soon(self.start) + await tg.start(self.start) return self async def __aexit__(self, exc_type, exc_value, exc_tb): - if self._task_group is None: - raise RuntimeError("YStore not running") - - self._task_group.cancel_scope.cancel() - self._task_group = None + await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): @@ -78,8 +71,8 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): """ if self._starting: return - else: - self._starting = True + + self._starting = True if self._task_group is not None: raise RuntimeError("YStore already running") @@ -88,7 +81,7 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): self._starting = False task_status.started() - def stop(self) -> None: + async def stop(self) -> None: """Stop the store.""" if self._task_group is None: raise RuntimeError("YStore not running") @@ -327,19 +320,14 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): Arguments: task_status: The status to set when the task has started. """ - if self._starting: - return - else: - self._starting = True + self._db = await connect(self.db_path) + await self._init_db() + await super().start(task_status=task_status) - if self._task_group is not None: - raise RuntimeError("YStore already running") - - async with create_task_group() as self._task_group: - self._task_group.start_soon(self._init_db) - self.started.set() - self._starting = False - task_status.started() + async def stop(self) -> None: + """Stop the store.""" + await self._db.close() + await super().stop() async def _init_db(self): create_db = False @@ -348,36 +336,36 @@ async def _init_db(self): create_db = True else: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute( - "SELECT count(name) FROM sqlite_master " - "WHERE type='table' and name='yupdates'" - ) - table_exists = (await cursor.fetchone())[0] - if table_exists: - cursor = await db.execute("pragma user_version") - version = (await cursor.fetchone())[0] - if version != self.version: - move_db = True - create_db = True - else: + cursor = await self._db.cursor() + await cursor.execute( + "SELECT count(name) FROM sqlite_master " + "WHERE type='table' and name='yupdates'" + ) + table_exists = (await cursor.fetchone())[0] + if table_exists: + await cursor.execute("pragma user_version") + version = (await cursor.fetchone())[0] + if version != self.version: + move_db = True create_db = True + else: + create_db = True if move_db: new_path = await get_new_path(self.db_path) self.log.warning(f"YStore version mismatch, moving {self.db_path} to {new_path}") await anyio.Path(self.db_path).rename(new_path) if create_db: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - "CREATE TABLE yupdates " - "(path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)" - ) - await db.execute( - "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" - ) - await db.execute(f"PRAGMA user_version = {self.version}") - await db.commit() + cursor = await self._db.cursor() + await cursor.execute( + "CREATE TABLE yupdates " + "(path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)" + ) + await cursor.execute( + "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" + ) + await cursor.execute(f"PRAGMA user_version = {self.version}") + await self._db.commit() self.db_initialized.set() async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore @@ -389,17 +377,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: igno await self.db_initialized.wait() try: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( - "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", - (self.path,), - ) as cursor: - found = False - async for update, metadata, timestamp in cursor: - found = True - yield update, metadata, timestamp - if not found: - raise YDocNotFound + cursor = await self._db.cursor() + await cursor.execute( + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", + (self.path,), + ) + found = False + for update, metadata, timestamp in await cursor.fetchall(): + found = True + yield update, metadata, timestamp + if not found: + raise YDocNotFound except Exception: raise YDocNotFound @@ -411,37 +399,35 @@ async def write(self, data: bytes) -> None: """ await self.db_initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - # first, determine time elapsed since last update - cursor = await db.execute( - "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", - (self.path,), - ) - row = await cursor.fetchone() - diff = (time.time() - row[0]) if row else 0 - - if self.document_ttl is not None and diff > self.document_ttl: - # squash updates - ydoc = Doc() - async with db.execute( - "SELECT yupdate FROM yupdates WHERE path = ?", (self.path,) - ) as cursor: - async for update, in cursor: - ydoc.apply_update(update) - # delete history - await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) - # insert squashed updates - squashed_update = ydoc.get_update() - metadata = await self.get_metadata() - await db.execute( - "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, squashed_update, metadata, time.time()), - ) - - # finally, write this update to the DB + # first, determine time elapsed since last update + cursor = await self._db.cursor() + await cursor.execute( + "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", + (self.path,), + ) + row = await cursor.fetchone() + diff = (time.time() - row[0]) if row else 0 + + if self.document_ttl is not None and diff > self.document_ttl: + # squash updates + ydoc = Doc() + await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)) + for (update,) in await cursor.fetchall(): + ydoc.apply_update(update) + # delete history + await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) + # insert squashed updates + squashed_update = ydoc.get_update() metadata = await self.get_metadata() - await db.execute( + await cursor.execute( "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, data, metadata, time.time()), + (self.path, squashed_update, metadata, time.time()), ) - await db.commit() + + # finally, write this update to the DB + metadata = await self.get_metadata() + await cursor.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (self.path, data, metadata, time.time()), + ) + await self._db.commit() diff --git a/plugins/yjs/fps_yjs/ywebsocket/yutils.py b/plugins/yjs/fps_yjs/ywebsocket/yutils.py index fe731116..5ccec736 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/yutils.py +++ b/plugins/yjs/fps_yjs/ywebsocket/yutils.py @@ -4,6 +4,7 @@ from pathlib import Path import anyio +from anyio import BrokenResourceError from anyio.streams.memory import MemoryObjectSendStream from pycrdt import Doc, TransactionEvent @@ -99,7 +100,10 @@ def read_var_string(self): def put_updates(update_send_stream: MemoryObjectSendStream, event: TransactionEvent) -> None: update = event.update # type: ignore - update_send_stream.send_nowait(update) + try: + update_send_stream.send_nowait(update) + except BrokenResourceError: + pass async def process_sync_message(message: bytes, ydoc: Doc, websocket, log) -> None: diff --git a/plugins/yjs/fps_yjs/ywidgets/widgets.py b/plugins/yjs/fps_yjs/ywidgets/widgets.py index 52eeae03..206e9560 100644 --- a/plugins/yjs/fps_yjs/ywidgets/widgets.py +++ b/plugins/yjs/fps_yjs/ywidgets/widgets.py @@ -11,6 +11,7 @@ process_sync_message, sync, ) + ypywidgets_installed = True except ImportError: ypywidgets_installed = False @@ -24,11 +25,10 @@ Widgets: Any if ypywidgets_installed: + class Widgets: # type: ignore def __init__(self): - self.ydocs = { - ep.name: ep.load() for ep in entry_points(group="ypywidgets") - } + self.ydocs = {ep.name: ep.load() for ep in entry_points(group="ypywidgets")} self.widgets = {} def comm_open(self, msg, comm) -> None: diff --git a/plugins/yjs/pyproject.toml b/plugins/yjs/pyproject.toml index 94ac7f81..b45cc29d 100644 --- a/plugins/yjs/pyproject.toml +++ b/plugins/yjs/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "importlib_metadata >=3.6; python_version<'3.10'", "pycrdt >=0.8.16,<0.9.0", "jupyverse-api >=0.1.2,<1", + "sqlite-anyio >=0.2.0,<0.3.0", ] dynamic = [ "version",] [[project.authors]] diff --git a/pyproject.toml b/pyproject.toml index 8486f9f2..e7e702a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,12 @@ docs = [ "mkdocs", "mkdocs-material" ] # pre-install commands here and post-install commands in the matrix can be moved # to the dependencies section pre-install-commands = [ + "pip install git+https://github.com/asphalt-framework/asphalt.git@5.0", + "pip install git+https://github.com/asphalt-framework/asphalt-web.git@asphalt5", + "pip install asgiref", + "pip install fastapi", + "pip install hypercorn", + "pip install -e ./jupyverse_api", "pip install -e ./plugins/contents", "pip install -e ./plugins/frontend", @@ -87,6 +93,7 @@ matrix.frontend.scripts = [ { key = "typecheck1", value = "typecheck0 ./plugins/jupyterlab", if = ["jupyterlab"] }, { key = "typecheck1", value = "typecheck0 ./plugins/notebook", if = ["notebook"] }, ] + matrix.auth.post-install-commands = [ { value = "pip install -e ./plugins/noauth", if = ["noauth"] }, { value = "pip install -e ./plugins/auth -e ./plugins/login", if = ["auth"] }, @@ -181,15 +188,3 @@ python_packages = [ [tool.hatch.version] path = "jupyverse/__init__.py" - -[tool.pytest.ini_options] -asyncio_mode = "strict" - -[tool.pixi.project] -name = "" -channels = ["conda-forge"] -platforms = ["linux-64"] - -[tool.pixi.dependencies] -pip = ">=24.0,<25" -python = "<3.12" diff --git a/tests/conftest.py b/tests/conftest.py index 748a3e70..983395db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +import signal import subprocess import time from pathlib import Path @@ -7,6 +8,12 @@ import requests +@pytest.fixture +def anyio_backend(): + # at least, SQLAlchemy doesn't support anything else than asyncio + return "asyncio" + + @pytest.fixture() def cwd(): return Path(__file__).parents[1] @@ -38,5 +45,5 @@ def start_jupyverse(auth_mode, clear_users, cwd, unused_tcp_port): else: break yield url - p.kill() + os.kill(p.pid, signal.SIGINT) p.wait() diff --git a/tests/data/notebook1.ipynb b/tests/data/notebook1.ipynb index e1c94429..6ea750cb 100644 --- a/tests/data/notebook1.ipynb +++ b/tests/data/notebook1.ipynb @@ -1,53 +1,56 @@ { "cells": [ { - "execution_count": null, + "execution_count": 1, "outputs": [], "id": "a7243792-6f06-4462-a6b5-7e9ec604348e", "source": "from ypywidgets_textual.switch import Switch", - "cell_type": "code", + "execution_state": "idle", "metadata": { "trusted": false - } + }, + "cell_type": "code" }, { + "execution_count": 2, + "cell_type": "code", + "outputs": [], + "execution_state": "busy", "id": "a7243792-6f06-4462-a6b5-7e9ec604348f", - "source": "switch = Switch()\nswitch", - "execution_count": null, "metadata": { "trusted": false }, - "outputs": [], - "cell_type": "code" + "source": "switch = Switch()\nswitch" }, { + "execution_state": "idle", "outputs": [], "id": "a7243792-6f06-4462-a6b5-7e9ec604349f", "source": "switch.toggle()", "cell_type": "code", + "execution_count": 3, "metadata": { "trusted": false - }, - "execution_count": null + } } ], "metadata": { "kernelspec": { - "language": "python", + "display_name": "Python 3 (ipykernel)", "name": "python3", - "display_name": "Python 3 (ipykernel)" + "language": "python" }, "language_info": { "version": "3.7.12", + "pygments_lexer": "ipython3", + "name": "python", + "nbconvert_exporter": "python", + "mimetype": "text/x-python", + "file_extension": ".py", "codemirror_mode": { "version": 3, "name": "ipython" - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "pygments_lexer": "ipython3", - "nbconvert_exporter": "python" + } } }, "nbformat": 4, diff --git a/tests/test_app.py b/tests/test_app.py index dfd97365..ff62fc0a 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,5 +1,5 @@ import pytest -from asphalt.core import Context +from asphalt.core import Context, get_resource from fastapi import APIRouter from httpx import AsyncClient from utils import configure @@ -9,7 +9,7 @@ from jupyverse_api.main import JupyverseComponent -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize( "mount_path", ( @@ -20,13 +20,13 @@ async def test_mount_path(mount_path, unused_tcp_port): components = configure({"app": {"type": "app"}}, {"app": {"mount_path": mount_path}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() - app = await ctx.request_resource(App) + app = await get_resource(App, wait=True) router = APIRouter() @router.get("/") diff --git a/tests/test_auth.py b/tests/test_auth.py index e8a3b5ed..9e35fa25 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,5 +1,5 @@ import pytest -from asphalt.core import Context +from asphalt.core import Context, get_resource from httpx import AsyncClient from httpx_ws import WebSocketUpgradeError, aconnect_ws from utils import authenticate_client, configure @@ -19,13 +19,13 @@ } -@pytest.mark.asyncio +@pytest.mark.anyio async def test_kernel_channels_unauthenticated(unused_tcp_port): - async with Context() as ctx: + async with Context(): await JupyverseComponent( components=COMPONENTS, port=unused_tcp_port, - ).start(ctx) + ).start() with pytest.raises(WebSocketUpgradeError): async with aconnect_ws( @@ -34,13 +34,13 @@ async def test_kernel_channels_unauthenticated(unused_tcp_port): pass -@pytest.mark.asyncio +@pytest.mark.anyio async def test_kernel_channels_authenticated(unused_tcp_port): - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=COMPONENTS, port=unused_tcp_port, - ).start(ctx) + ).start() await authenticate_client(http, unused_tcp_port) async with aconnect_ws( @@ -50,15 +50,15 @@ async def test_kernel_channels_authenticated(unused_tcp_port): pass -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth", "token", "user")) async def test_root_auth(auth_mode, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/") if auth_mode == "noauth": @@ -70,31 +70,31 @@ async def test_root_auth(auth_mode, unused_tcp_port): assert response.headers["content-type"] == "application/json" -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_no_auth(auth_mode, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/lab") assert response.status_code == 200 -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("token",)) async def test_token_auth(auth_mode, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() - auth_config = await ctx.request_resource(AuthConfig) + auth_config = await get_resource(AuthConfig, wait=True) # no token provided, should not work response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/") @@ -104,7 +104,7 @@ async def test_token_auth(auth_mode, unused_tcp_port): assert response.status_code == 302 -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("user",)) @pytest.mark.parametrize( "permissions", @@ -115,11 +115,11 @@ async def test_token_auth(auth_mode, unused_tcp_port): ) async def test_permissions(auth_mode, permissions, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() await authenticate_client(http, unused_tcp_port, permissions=permissions) response = await http.get(f"http://127.0.0.1:{unused_tcp_port}/auth/user/me") diff --git a/tests/test_contents.py b/tests/test_contents.py index b44a4aac..1262bd6a 100644 --- a/tests/test_contents.py +++ b/tests/test_contents.py @@ -16,7 +16,7 @@ } -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_tree(auth_mode, tmp_path, unused_tcp_port): prev_dir = os.getcwd() @@ -65,11 +65,11 @@ async def test_tree(auth_mode, tmp_path, unused_tcp_port): ) components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() response = await http.get( f"http://127.0.0.1:{unused_tcp_port}/api/contents", params={"content": 1} diff --git a/tests/test_execute.py b/tests/test_execute.py index d423f1a1..543156d3 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -1,9 +1,9 @@ -import asyncio import os from functools import partial from pathlib import Path import pytest +from anyio import create_memory_object_stream, create_task_group, sleep from asphalt.core import Context from fps_yjs.ydocs import ydocs from fps_yjs.ywebsocket import WebsocketProvider @@ -55,7 +55,7 @@ async def recv(self) -> bytes: return bytes(b) -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_execute(auth_mode, unused_tcp_port): url = f"http://127.0.0.1:{unused_tcp_port}" @@ -63,11 +63,11 @@ async def test_execute(auth_mode, unused_tcp_port): "auth": {"mode": auth_mode}, "kernels": {"require_yjs": True}, }) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() ws_url = url.replace("http", "ws", 1) name = "notebook1.ipynb" @@ -90,23 +90,24 @@ async def test_execute(auth_mode, unused_tcp_port): json={ "format": "json", "type": "notebook", - } + }, + timeout=20, ) file_id = response.json()["fileId"] document_id = f"json:notebook:{file_id}" ynb = ydocs["notebook"]() - def callback(aevent, events, event): + def callback(event_stream_send, events, event): events.append(event) - aevent.set() - aevent = asyncio.Event() + event_stream_send.send_nowait(None) + event_stream_send, event_stream_recv = create_memory_object_stream[None](1) events = [] - ynb.ydoc.observe_subdocs(partial(callback, aevent, events)) + ynb.ydoc.observe_subdocs(partial(callback, event_stream_send, events)) async with aconnect_ws( f"{ws_url}/api/collaboration/room/{document_id}" ) as websocket, WebsocketProvider(ynb.ydoc, Websocket(websocket, document_id)): # connect to the shared notebook document # wait for file to be loaded and Y model to be created in server and client - await asyncio.sleep(0.5) + await sleep(0.5) # execute notebook for cell_idx in range(2): response = await http.post( @@ -117,23 +118,22 @@ def callback(aevent, events, event): } ) while True: - await aevent.wait() - aevent.clear() + await event_stream_recv.receive() guid = None for event in events: if event.added: guid = event.added[0] if guid is not None: break - task = asyncio.create_task(connect_ywidget(ws_url, guid)) - response = await http.post( - f"{url}/api/kernels/{kernel_id}/execute", - json={ - "document_id": document_id, - "cell_id": ynb.ycells[2]["id"], - } - ) - await task + async with create_task_group() as tg: + tg.start_soon(connect_ywidget, ws_url, guid) + response = await http.post( + f"{url}/api/kernels/{kernel_id}/execute", + json={ + "document_id": document_id, + "cell_id": ynb.ycells[2]["id"], + } + ) async def connect_ywidget(ws_url, guid): @@ -141,10 +141,8 @@ async def connect_ywidget(ws_url, guid): async with aconnect_ws( f"{ws_url}/api/collaboration/room/ywidget:{guid}" ) as websocket, WebsocketProvider(ywidget_doc, Websocket(websocket, guid)): - await asyncio.sleep(0.5) - attrs = Map() - model_name = Text() - ywidget_doc["_attrs"] = attrs - ywidget_doc["_model_name"] = model_name + await sleep(0.5) + ywidget_doc["_attrs"] = attrs = Map() + ywidget_doc["_model_name"] = model_name = Text() assert str(model_name) == "Switch" assert str(attrs) == '{"value":true}' diff --git a/tests/test_kernels.py b/tests/test_kernels.py index dba726b9..77f4cd8c 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -4,6 +4,7 @@ from time import sleep import pytest +from anyio import create_task_group from asphalt.core import Context from fps_kernels.kernel_server.server import KernelServer, kernels from httpx import AsyncClient @@ -26,9 +27,9 @@ } -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) -async def test_kernel_messages(auth_mode, capfd, unused_tcp_port): +async def test_kernel_messages(auth_mode, unused_tcp_port, capfd): kernel_id = "kernel_id_0" kernel_name = "python3" kernelspec_path = ( @@ -36,67 +37,70 @@ async def test_kernel_messages(auth_mode, capfd, unused_tcp_port): ) assert kernelspec_path.exists() kernel_server = KernelServer(kernelspec_path=kernelspec_path, capture_kernel_output=False) - await kernel_server.start() - kernels[kernel_id] = {"server": kernel_server} - msg_id = "0" - msg = { - "channel": "shell", - "parent_header": None, - "content": None, - "metadata": None, - "header": { - "msg_type": "msg_type_0", - "msg_id": msg_id, - }, - } + async with create_task_group() as tg: + await tg.start(kernel_server.start) + kernels[kernel_id] = {"server": kernel_server, "driver": None} + msg_id = "0" + msg = { + "channel": "shell", + "parent_header": None, + "content": None, + "metadata": None, + "header": { + "msg_type": "msg_type_0", + "msg_id": msg_id, + }, + } - components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient(): - await JupyverseComponent( - components=components, - port=unused_tcp_port, - ).start(ctx) + components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) + async with Context(), AsyncClient(): + await JupyverseComponent( + components=components, + port=unused_tcp_port, + ).start() - # block msg_type_0 - msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) - kernel_server.block_messages("msg_type_0") - async with aconnect_ws( - f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", - ) as websocket: - await websocket.send_json(msg) - sleep(0.5) - out, err = capfd.readouterr() - assert not err + # block msg_type_0 + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.block_messages("msg_type_0") + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + await websocket.send_json(msg) + sleep(0.5) + out, err = capfd.readouterr() + assert "IPKernelApp" not in err - # allow only msg_type_0 - msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) - kernel_server.allow_messages("msg_type_0") - async with aconnect_ws( - f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", - ) as websocket: - await websocket.send_json(msg) - sleep(0.5) - out, err = capfd.readouterr() - assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") == 1 + # allow only msg_type_0 + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages("msg_type_0") + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + await websocket.send_json(msg) + sleep(0.5) + out, err = capfd.readouterr() + assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") == 1 - # block all messages - msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) - kernel_server.allow_messages([]) - async with aconnect_ws( - f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", - ) as websocket: - await websocket.send_json(msg) - sleep(0.5) - out, err = capfd.readouterr() - assert not err + # block all messages + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages([]) + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + await websocket.send_json(msg) + sleep(0.5) + out, err = capfd.readouterr() + assert "IPKernelApp" not in err - # allow all messages - msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) - kernel_server.allow_messages() - async with aconnect_ws( - f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", - ) as websocket: - await websocket.send_json(msg) - sleep(0.5) - out, err = capfd.readouterr() - assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") >= 1 + # allow all messages + msg["header"]["msg_id"] = str(int(msg["header"]["msg_id"]) + 1) + kernel_server.allow_messages() + async with aconnect_ws( + f"http://127.0.0.1:{unused_tcp_port}/api/kernels/{kernel_id}/channels?session_id=session_id_0", + ) as websocket: + await websocket.send_json(msg) + sleep(0.5) + out, err = capfd.readouterr() + assert err.count("[IPKernelApp] WARNING | Unknown message type: 'msg_type_0'") >= 1 + + tg.start_soon(kernel_server.stop) diff --git a/tests/test_server.py b/tests/test_server.py index bc2325d4..6ebf11d8 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,10 +1,10 @@ -import asyncio import json from functools import partial from pathlib import Path import pytest import requests +from anyio import create_memory_object_stream, create_task_group, sleep from fps_yjs.ydocs import ydocs from fps_yjs.ywebsocket import WebsocketProvider from pycrdt import Array, Doc, Map, Text @@ -47,7 +47,7 @@ def test_settings_persistence_get(start_jupyverse): assert response.status_code == 204 -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) @pytest.mark.parametrize("clear_users", (False,)) async def test_rest_api(start_jupyverse): @@ -87,7 +87,7 @@ async def test_rest_api(start_jupyverse): ) as websocket, WebsocketProvider(ydoc, websocket): # connect to the shared notebook document # wait for file to be loaded and Y model to be created in server and client - await asyncio.sleep(0.5) + await sleep(0.5) ydoc["cells"] = ycells = Array() # execute notebook for cell_idx in range(3): @@ -101,7 +101,7 @@ async def test_rest_api(start_jupyverse): ), ) # wait for Y model to be updated - await asyncio.sleep(0.5) + await sleep(0.5) # retrieve cells cells = json.loads(str(ycells)) assert cells[0]["outputs"] == [ @@ -125,7 +125,7 @@ async def test_rest_api(start_jupyverse): ] -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) @pytest.mark.parametrize("clear_users", (False,)) async def test_ywidgets(start_jupyverse): @@ -139,7 +139,6 @@ async def test_ywidgets(start_jupyverse): data=json.dumps( { "kernel": {"name": "python3"}, - #"kernel": {"name": "akernel"}, "name": name, "path": path, "type": "notebook", @@ -161,18 +160,18 @@ async def test_ywidgets(start_jupyverse): file_id = response.json()["fileId"] document_id = f"json:notebook:{file_id}" ynb = ydocs["notebook"]() - def callback(aevent, events, event): + def callback(event_stream_send, events, event): events.append(event) - aevent.set() - aevent = asyncio.Event() + event_stream_send.send_nowait(None) + event_stream_send, event_stream_recv = create_memory_object_stream[None](1) events = [] - ynb.ydoc.observe_subdocs(partial(callback, aevent, events)) + ynb.ydoc.observe_subdocs(partial(callback, event_stream_send, events)) async with connect( f"{ws_url}/api/collaboration/room/{document_id}" ) as websocket, WebsocketProvider(ynb.ydoc, websocket): # connect to the shared notebook document # wait for file to be loaded and Y model to be created in server and client - await asyncio.sleep(0.5) + await sleep(0.5) # execute notebook for cell_idx in range(2): response = requests.post( @@ -185,25 +184,24 @@ def callback(aevent, events, event): ), ) while True: - await aevent.wait() - aevent.clear() + await event_stream_recv.receive() guid = None for event in events: if event.added: guid = event.added[0] if guid is not None: break - task = asyncio.create_task(connect_ywidget(ws_url, guid)) - response = requests.post( - f"{url}/api/kernels/{kernel_id}/execute", - data=json.dumps( - { - "document_id": document_id, - "cell_id": ynb.ycells[2]["id"], - } - ), - ) - await task + async with create_task_group() as tg: + tg.start_soon(connect_ywidget, ws_url, guid) + response = requests.post( + f"{url}/api/kernels/{kernel_id}/execute", + data=json.dumps( + { + "document_id": document_id, + "cell_id": ynb.ycells[2]["id"], + } + ), + ) async def connect_ywidget(ws_url, guid): @@ -211,7 +209,7 @@ async def connect_ywidget(ws_url, guid): async with connect( f"{ws_url}/api/collaboration/room/ywidget:{guid}" ) as websocket, WebsocketProvider(ywidget_doc, websocket): - await asyncio.sleep(0.5) + await sleep(0.5) attrs = Map() model_name = Text() ywidget_doc["_attrs"] = attrs diff --git a/tests/test_settings.py b/tests/test_settings.py index 03cb6a60..1ee953f2 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -21,15 +21,15 @@ } -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.parametrize("auth_mode", ("noauth",)) async def test_settings(auth_mode, unused_tcp_port): components = configure(COMPONENTS, {"auth": {"mode": auth_mode}}) - async with Context() as ctx, AsyncClient() as http: + async with Context(), AsyncClient() as http: await JupyverseComponent( components=components, port=unused_tcp_port, - ).start(ctx) + ).start() # get previous theme response = await http.get( @@ -40,7 +40,7 @@ async def test_settings(auth_mode, unused_tcp_port): # put new theme response = await http.put( f"http://127.0.0.1:{unused_tcp_port}/lab/api/settings/@jupyterlab/apputils-extension:themes", - data=json.dumps(test_theme), + content=json.dumps(test_theme), ) assert response.status_code == 204 # get new theme @@ -52,6 +52,6 @@ async def test_settings(auth_mode, unused_tcp_port): # put previous theme back response = await http.put( f"http://127.0.0.1:{unused_tcp_port}/lab/api/settings/@jupyterlab/apputils-extension:themes", - data=json.dumps(theme), + content=json.dumps(theme), ) assert response.status_code == 204