Skip to content

Commit

Permalink
Implement file ID service
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 10, 2022
1 parent 9e5009f commit a555dc1
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [ '3.7', '3.8', '3.9', '3.10' ]
python-version: [ '3.10' ]

steps:
- name: Checkout
Expand Down
Empty file.
2 changes: 2 additions & 0 deletions plugins/contents/fps_contents/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .fileid import get_watch # noqa

__version__ = "0.0.44"
175 changes: 175 additions & 0 deletions plugins/contents/fps_contents/fileid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import asyncio
from pathlib import Path
from typing import Callable, Dict, List, Optional
from uuid import uuid4

import aiosqlite
from aiopath import AsyncPath # type: ignore
from watchfiles import Change, awatch


class Watcher:
def __init__(self, path: str) -> None:
self.path = path
self._event = asyncio.Event()

def __aiter__(self):
return self

async def __anext__(self):
await self._event.wait()
self._event.clear()
return self._change

def notify(self, change):
self._change = change
self._event.set()


class FileIdManager:

db_path: str
initialized: asyncio.Event
watchers: Dict[str, List[Watcher]]

def __init__(self, db_path: str = "fileid.db"):
self.db_path = db_path
self.initialized = asyncio.Event()
self.watchers = {}
asyncio.create_task(self.watch_files())

async def get_id(self, path: str) -> Optional[str]:
await self.initialized.wait()
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("SELECT id FROM fileids WHERE path = ?", (path,)) as cursor:
async for idx, in cursor:
return idx
return None

async def get_path(self, idx: str) -> Optional[str]:
await self.initialized.wait()
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("SELECT path FROM fileids WHERE id = ?", (idx,)) as cursor:
async for path, in cursor:
return path
return None

async def index(self, path: str) -> Optional[str]:
await self.initialized.wait()
async with aiosqlite.connect(self.db_path) as db:
apath = AsyncPath(path)
if not await apath.exists():
return None

idx = uuid4().hex
mtime = (await apath.stat()).st_mtime
await db.execute("INSERT INTO fileids VALUES (?, ?, ?)", (idx, path, mtime))
await db.commit()
return idx

async def watch_files(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"CREATE TABLE IF NOT EXISTS fileids "
"(id TEXT PRIMARY KEY, path TEXT NOT NULL, mtime REAL NOT NULL)"
)
await db.commit()

async with aiosqlite.connect(self.db_path) as db:
async for path in AsyncPath().rglob("*"):
idx = uuid4().hex
mtime = (await path.stat()).st_mtime
await db.execute("INSERT INTO fileids VALUES (?, ?, ?)", (idx, str(path), mtime))
await db.commit()
self.initialized.set()

async for changes in awatch("."):
deleted_paths = []
added_paths = []
for change, changed_path in changes:
# get relative path
changed_path = str(Path(changed_path).relative_to(Path().absolute()))

if change == Change.deleted:
async with db.execute(
"SELECT * FROM fileids WHERE path = ?", (changed_path,)
) as cursor:
async for _ in cursor:
break
else:
# path is not indexed, ignore
continue
# path is indexed
await maybe_rename(db, changed_path, deleted_paths, added_paths, False)
elif change == Change.added:
await maybe_rename(db, changed_path, added_paths, deleted_paths, True)

for path in deleted_paths + added_paths:
await db.execute("DELETE FROM fileids WHERE path = ?", (path,))
await db.commit()

for change in changes:
changed_path = change[1]
# get relative path
changed_path = str(Path(changed_path).relative_to(Path().absolute()))
for watcher in self.watchers.get(changed_path, []):
watcher.notify(change)

def watch(self, path: str) -> Watcher:
watcher = Watcher(path)
if path not in self.watchers:
self.watchers[path] = watchers = [] # type: ignore
watchers.append(watcher)
return watcher


async def get_mtime(path, db) -> Optional[float]:
if db:
async with db.execute("SELECT * FROM fileids WHERE path = ?", (path,)) as cursor:
async for _, _, mtime in cursor:
return mtime
# deleted file is not in database, shouldn't happen
return None
try:
mtime = (await AsyncPath(path).stat()).st_mtime
except FileNotFoundError:
return None
return mtime


async def maybe_rename(
db, changed_path: str, changed_paths: List[str], other_paths: List[str], is_added_path
) -> None:
# check if the same file was added/deleted, this would be a rename
db_or_fs1, db_or_fs2 = db, None
if is_added_path:
db_or_fs1, db_or_fs2 = db_or_fs2, db_or_fs1
mtime1 = await get_mtime(changed_path, db_or_fs1)
if mtime1 is None:
return
for other_path in other_paths:
mtime2 = await get_mtime(other_path, db_or_fs2)
if mtime1 == mtime2:
# same files, according to modification times
path1, path2 = changed_path, other_path
if is_added_path:
path1, path2 = path2, path1
await db.execute("UPDATE fileids SET path = REPLACE(path, ?, ?)", (path1, path2))
other_paths.remove(other_path)
return
changed_paths.append(changed_path)


FILE_ID_MANAGER: Optional[FileIdManager] = None


def get_file_id_manager() -> FileIdManager:
global FILE_ID_MANAGER
if FILE_ID_MANAGER is None:
FILE_ID_MANAGER = FileIdManager()
assert FILE_ID_MANAGER is not None
return FILE_ID_MANAGER


def get_watch() -> Callable[[str], Watcher]:
return get_file_id_manager().watch
2 changes: 1 addition & 1 deletion plugins/contents/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "fps_contents"
description = "An FPS plugin for the contents API"
keywords = [ "jupyter", "server", "fastapi", "pluggy", "plugins",]
requires-python = ">=3.7"
dependencies = [ "fps >=0.0.8", "fps-auth-base", "anyio", "watchfiles >=0.16.1,<1",]
dependencies = [ "fps >=0.0.8", "fps-auth-base", "anyio", "watchfiles >=0.16.1,<1", "aiosqlite >=0.17.0,<1", "aiopath >=0.6.11,<1"]
dynamic = [ "version",]
[[project.authors]]
name = "Jupyter Development Team"
Expand Down
6 changes: 6 additions & 0 deletions plugins/yjs/fps_yjs/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class CreateRoomId(BaseModel):
format: str
type: str
66 changes: 54 additions & 12 deletions plugins/yjs/fps_yjs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,36 @@
from pathlib import Path
from typing import Optional, Tuple

from fastapi import APIRouter, Depends, WebSocketDisconnect
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
Response,
WebSocketDisconnect,
status,
)
from fastapi.responses import PlainTextResponse
from fps.hooks import register_router # type: ignore
from fps_auth_base import User, current_user
from fps_contents.fileid import get_file_id_manager
from fps_contents.routes import read_content, write_content # type: ignore

try:
from fps_contents.watchfiles import awatch
from fps_contents import get_watch

has_awatch = True
has_watch = True
except ImportError:
has_awatch = False
has_watch = False

from fps_auth_base import websocket_auth # type: ignore
from jupyter_ydoc import ydocs as YDOCS # type: ignore
from ypy_websocket.websocket_server import WebsocketServer, YRoom # type: ignore
from ypy_websocket.ystore import BaseYStore, SQLiteYStore, YDocNotFound # type: ignore
from ypy_websocket.yutils import YMessageType # type: ignore

from .models import CreateRoomId

YFILE = YDOCS["file"]
AWARENESS = 1
RENAME_SESSION = 127
Expand Down Expand Up @@ -126,9 +140,14 @@ def __init__(self, websocket, path, permissions):
self.room = self.websocket_server.get_room(self.websocket.path)
self.set_file_info(path)

def get_file_info(self) -> Tuple[str, str, str]:
async def get_file_info(self) -> Tuple[str, str, str]:
room_name = self.websocket_server.get_room_name(self.room)
file_format, file_type, file_path = room_name.split(":", 2)
file_format, file_type, file_id = room_name.split(":", 2)
file_path = await get_file_id_manager().get_path(file_id)
if file_path is None:
raise RuntimeError(f"File {self.room.document.path} cannot be found anymore")
if file_path != self.room.document.path:
self.room.document.path = file_path
return file_format, file_type, file_path

def set_file_info(self, value: str) -> None:
Expand All @@ -145,7 +164,7 @@ async def serve(self):
self.room.cleaner.cancel()

if not self.room.is_transient and not self.room.ready:
file_format, file_type, file_path = self.get_file_info()
file_format, file_type, file_path = await self.get_file_info()
is_notebook = file_type == "notebook"
model = await read_content(file_path, True, as_json=is_notebook)
self.last_modified = to_datetime(model.last_modified)
Expand Down Expand Up @@ -212,9 +231,9 @@ async def on_message(self, message: bytes) -> bool:
return skip

async def watch_file(self):
if has_awatch:
file_format, file_type, file_path = self.get_file_info()
async for changes in awatch(file_path):
if has_watch:
file_format, file_type, file_path = await self.get_file_info()
async for changes in get_watch()(file_path):
await self.maybe_load_document()
else:
# contents plugin doesn't provide watcher, fall back to polling
Expand All @@ -227,7 +246,7 @@ async def watch_file(self):
await self.maybe_load_document()

async def maybe_load_document(self):
file_format, file_type, file_path = self.get_file_info()
file_format, file_type, file_path = await self.get_file_info()
model = await read_content(file_path, False)
# do nothing if the file was saved by us
if self.last_modified < to_datetime(model.last_modified):
Expand Down Expand Up @@ -266,7 +285,7 @@ async def maybe_save_document(self):
await asyncio.sleep(1)
# if the room cannot be found, don't save
try:
file_format, file_type, file_path = self.get_file_info()
file_format, file_type, file_path = await self.get_file_info()
except Exception:
return
is_notebook = file_type == "notebook"
Expand All @@ -293,4 +312,27 @@ async def maybe_save_document(self):
self.room.document.dirty = False


@router.put("/api/yjs/roomid/{path:path}", status_code=200, response_class=PlainTextResponse)
async def create_roomid(
path,
request: Request,
response: Response,
user: User = Depends(current_user(permissions={"contents": ["read"]})),
):
# we need to process the request manually
# see https://github.com/tiangolo/fastapi/issues/3373#issuecomment-1306003451
create_room_id = CreateRoomId(**(await request.json()))
ws_url = f"{create_room_id.format}:{create_room_id.type}:"
idx = await get_file_id_manager().get_id(path)
if idx is not None:
return ws_url + idx

idx = await get_file_id_manager().index(path)
if idx is None:
raise HTTPException(status_code=404, detail=f"File {path} does not exist")

response.status_code = status.HTTP_201_CREATED
return ws_url + idx


r = register_router(router)
19 changes: 16 additions & 3 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import sys
from pathlib import Path

import pytest
import requests
Expand Down Expand Up @@ -55,21 +56,33 @@ def test_settings_persistence_get(start_jupyverse):
async def test_rest_api(start_jupyverse):
url = start_jupyverse
ws_url = url.replace("http", "ws", 1)
name = "notebook0.ipynb"
path = (Path("tests") / "data" / name).as_posix()
# create a session to launch a kernel
response = requests.post(
f"{url}/api/sessions",
data=json.dumps(
{
"kernel": {"name": "python3"},
"name": "notebook0.ipynb",
"path": "69e8a762-86c6-4102-a3da-a43d735fec2b",
"name": name,
"path": path,
"type": "notebook",
}
),
)
r = response.json()
kernel_id = r["kernel"]["id"]
document_id = "json:notebook:tests/data/notebook0.ipynb"
# get the room ID for the document
response = requests.put(
f"{url}/api/yjs/roomid/{path}",
data=json.dumps(
{
"format": "json",
"type": "notebook",
}
),
)
document_id = response.text
async with connect(f"{ws_url}/api/yjs/{document_id}") as websocket:
# connect to the shared notebook document
ydoc = Y.YDoc()
Expand Down

0 comments on commit a555dc1

Please sign in to comment.