"""Persistent mapping between file paths and file IDs, kept up to date by a file watcher."""

import asyncio
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional
from uuid import uuid4

import aiosqlite
from aiopath import AsyncPath  # type: ignore
from watchfiles import Change, awatch


class Watcher:
    """Async iterator yielding the changes reported for one watched path.

    ``notify()`` stores a change and wakes up the consumer blocked in
    ``__anext__``. Only the most recent change is kept: calling ``notify()``
    again before the consumer resumes overwrites the previous one.
    """

    def __init__(self, path: str) -> None:
        self.path = path
        self._event = asyncio.Event()

    def __aiter__(self):
        return self

    async def __anext__(self):
        # block until notify() is called, then re-arm for the next change
        await self._event.wait()
        self._event.clear()
        return self._change

    def notify(self, change):
        """Record ``change`` and wake up the consumer of this watcher."""
        self._change = change
        self._event.set()


class FileIdManager:
    """Assign an ID to every file under the current directory and track renames.

    IDs are stored in an SQLite table ``fileids (id, path, mtime)``. Creating
    an instance schedules ``watch_files()`` as a task, so it must be created
    while an event loop is running.
    """

    db_path: str
    initialized: asyncio.Event
    watchers: Dict[str, List[Watcher]]

    def __init__(self, db_path: str = "fileid.db"):
        self.db_path = db_path
        self.initialized = asyncio.Event()
        self.watchers = {}
        # Keep a strong reference to the task: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected mid-run.
        self._watch_task = asyncio.create_task(self.watch_files())

    async def get_id(self, path: str) -> Optional[str]:
        """Return the ID indexed for ``path``, or None if the path is not indexed."""
        await self.initialized.wait()
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute("SELECT id FROM fileids WHERE path = ?", (path,)) as cursor:
                async for idx, in cursor:
                    return idx
                return None

    async def get_path(self, idx: str) -> Optional[str]:
        """Return the path currently associated with ID ``idx``, or None if unknown."""
        await self.initialized.wait()
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute("SELECT path FROM fileids WHERE id = ?", (idx,)) as cursor:
                async for path, in cursor:
                    return path
                return None

    async def index(self, path: str) -> Optional[str]:
        """Create a new ID for ``path`` and return it, or None if the file does not exist."""
        await self.initialized.wait()
        apath = AsyncPath(path)
        if not await apath.exists():
            return None

        idx = uuid4().hex
        mtime = (await apath.stat()).st_mtime
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("INSERT INTO fileids VALUES (?, ?, ?)", (idx, path, mtime))
            await db.commit()
        return idx

    async def watch_files(self):
        """Index the files under the current directory, then follow file system changes.

        Sets ``initialized`` once the initial indexing is committed, then loops
        forever over ``awatch(".")``: renames are detected through
        ``maybe_rename``, remaining deleted/added paths are removed from the
        table, and the watchers registered for a changed path are notified.
        """
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                "CREATE TABLE IF NOT EXISTS fileids "
                "(id TEXT PRIMARY KEY, path TEXT NOT NULL, mtime REAL NOT NULL)"
            )
            await db.commit()

        async with aiosqlite.connect(self.db_path) as db:
            # The table survives restarts (CREATE TABLE IF NOT EXISTS): only index
            # paths that are not already known, otherwise rows pile up on each run.
            async with db.execute("SELECT path FROM fileids") as cursor:
                indexed_paths = {row[0] async for row in cursor}
            async for path in AsyncPath().rglob("*"):
                if str(path) in indexed_paths:
                    continue
                try:
                    mtime = (await path.stat()).st_mtime
                except FileNotFoundError:
                    # The file vanished during the scan. Skip it: failing here
                    # would leave `initialized` unset and block every caller.
                    continue
                await db.execute(
                    "INSERT INTO fileids VALUES (?, ?, ?)", (uuid4().hex, str(path), mtime)
                )
            await db.commit()
            self.initialized.set()

            async for changes in awatch("."):
                root = Path().absolute()
                deleted_paths: List[str] = []
                added_paths: List[str] = []
                for change, changed_path in changes:
                    # get relative path
                    changed_path = str(Path(changed_path).relative_to(root))

                    if change == Change.deleted:
                        is_indexed = False
                        async with db.execute(
                            "SELECT * FROM fileids WHERE path = ?", (changed_path,)
                        ) as cursor:
                            async for _ in cursor:
                                is_indexed = True
                                break
                        if not is_indexed:
                            # path is not indexed, ignore
                            continue
                        await maybe_rename(db, changed_path, deleted_paths, added_paths, False)
                    elif change == Change.added:
                        await maybe_rename(db, changed_path, added_paths, deleted_paths, True)

                # whatever was not paired as a rename is dropped from the table
                for path in deleted_paths + added_paths:
                    await db.execute("DELETE FROM fileids WHERE path = ?", (path,))
                await db.commit()

                for change in changes:
                    # get relative path
                    changed_path = str(Path(change[1]).relative_to(root))
                    for watcher in self.watchers.get(changed_path, []):
                        watcher.notify(change)

    def watch(self, path: str) -> Watcher:
        """Register and return a new ``Watcher`` for ``path``.

        Several watchers may be registered for the same path; each of them is
        notified of the changes reported for that path.
        """
        watcher = Watcher(path)
        # setdefault() also covers an already-watched path, where the previous
        # code referenced an unbound local and raised UnboundLocalError.
        self.watchers.setdefault(path, []).append(watcher)
        return watcher


async def get_mtime(path, db) -> Optional[float]:
    """Return the modification time of ``path``.

    When ``db`` is given, the time is read from the ``fileids`` table (None if
    the path is not indexed); otherwise it is read from the file system (None
    if the file does not exist).
    """
    if db:
        async with db.execute("SELECT * FROM fileids WHERE path = ?", (path,)) as cursor:
            async for _, _, mtime in cursor:
                return mtime
            # deleted file is not in database, shouldn't happen
            return None
    try:
        mtime = (await AsyncPath(path).stat()).st_mtime
    except FileNotFoundError:
        return None
    return mtime


async def maybe_rename(
    db, changed_path: str, changed_paths: List[str], other_paths: List[str], is_added_path
) -> None:
    """Pair ``changed_path`` with an opposite change of the same batch, i.e. a rename.

    Two paths are considered the same file when the modification time stored
    in the database for the deleted path equals the one read from the file
    system for the added path. On a match the table is updated to the new path
    and the matched entry is removed from ``other_paths``; otherwise
    ``changed_path`` is appended to ``changed_paths``.

    :param db: open database connection.
    :param changed_path: the added or deleted path being processed.
    :param changed_paths: unpaired paths of the same kind as ``changed_path``.
    :param other_paths: unpaired paths of the opposite kind.
    :param is_added_path: True if ``changed_path`` was added, False if deleted.
    """
    # a deleted path only exists in the database, an added path only on disk
    db_or_fs1, db_or_fs2 = (None, db) if is_added_path else (db, None)
    mtime1 = await get_mtime(changed_path, db_or_fs1)
    if mtime1 is None:
        return
    for other_path in other_paths:
        mtime2 = await get_mtime(other_path, db_or_fs2)
        if mtime1 == mtime2:
            # same files, according to modification times
            if is_added_path:
                old_path, new_path = other_path, changed_path
            else:
                old_path, new_path = changed_path, other_path
            # Rewrite the renamed entry itself and, for a directory, everything
            # below it. Matching on the exact path or on the "old_path/" prefix
            # avoids touching unrelated rows that merely contain old_path.
            prefix = old_path + os.sep
            await db.execute(
                "UPDATE fileids SET path = ? || substr(path, ?) "
                "WHERE path = ? OR substr(path, 1, ?) = ?",
                (new_path, len(old_path) + 1, old_path, len(prefix), prefix),
            )
            other_paths.remove(other_path)
            return
    changed_paths.append(changed_path)


FILE_ID_MANAGER: Optional[FileIdManager] = None


def get_file_id_manager() -> FileIdManager:
    """Return the process-wide ``FileIdManager``, creating it on first use."""
    global FILE_ID_MANAGER
    if FILE_ID_MANAGER is None:
        FILE_ID_MANAGER = FileIdManager()
    assert FILE_ID_MANAGER is not None
    return FILE_ID_MANAGER


def get_watch() -> Callable[[str], Watcher]:
    """Return the ``watch`` method of the process-wide ``FileIdManager``."""
    return get_file_id_manager().watch
async def get_file_info(self) -> Tuple[str, str, str]:
    """Return ``(file_format, file_type, file_path)`` for this room.

    The room name has the form ``format:type:file_id``; the file ID is
    resolved to its current path through the file ID manager, and the room
    document's path is updated when it differs.

    Raises ``RuntimeError`` if the file ID no longer maps to a path.
    """
    name_parts = self.websocket_server.get_room_name(self.room).split(":", 2)
    file_format, file_type, file_id = name_parts

    manager = get_file_id_manager()
    current_path = await manager.get_path(file_id)
    if current_path is None:
        raise RuntimeError(f"File {self.room.document.path} cannot be found anymore")

    # follow renames: keep the document pointing at the file's current path
    if self.room.document.path != current_path:
        self.room.document.path = current_path
    return file_format, file_type, current_path
@router.put("/api/yjs/roomid/{path:path}", status_code=200, response_class=PlainTextResponse)
async def create_roomid(
    path,
    request: Request,
    response: Response,
    user: User = Depends(current_user(permissions={"contents": ["read"]})),
):
    """Return the room ID ``format:type:file_id`` for the file at ``path``.

    Responds with 200 when the file already has an ID, with 201 when an ID
    had to be created, and with 404 when the file does not exist.
    """
    # we need to process the request manually
    # see https://github.com/tiangolo/fastapi/issues/3373#issuecomment-1306003451
    payload = await request.json()
    room = CreateRoomId(**payload)
    room_prefix = f"{room.format}:{room.type}:"

    manager = get_file_id_manager()
    file_id = await manager.get_id(path)
    if file_id is None:
        # not indexed yet: create an ID now, which only succeeds for an existing file
        file_id = await manager.index(path)
        if file_id is None:
            raise HTTPException(status_code=404, detail=f"File {path} does not exist")
        response.status_code = status.HTTP_201_CREATED

    return room_prefix + file_id
start_jupyverse ws_url = url.replace("http", "ws", 1) + name = "notebook0.ipynb" + path = (Path("tests") / "data" / name).as_posix() # create a session to launch a kernel response = requests.post( f"{url}/api/sessions", data=json.dumps( { "kernel": {"name": "python3"}, - "name": "notebook0.ipynb", - "path": "69e8a762-86c6-4102-a3da-a43d735fec2b", + "name": name, + "path": path, "type": "notebook", } ), ) r = response.json() kernel_id = r["kernel"]["id"] - document_id = "json:notebook:tests/data/notebook0.ipynb" + # get the room ID for the document + response = requests.put( + f"{url}/api/yjs/roomid/{path}", + data=json.dumps( + { + "format": "json", + "type": "notebook", + } + ), + ) + document_id = response.text async with connect(f"{ws_url}/api/yjs/{document_id}") as websocket: # connect to the shared notebook document ydoc = Y.YDoc()