diff --git a/plugins/auth/fps_auth/db.py b/plugins/auth/fps_auth/db.py new file mode 100644 index 00000000..a27d7709 --- /dev/null +++ b/plugins/auth/fps_auth/db.py @@ -0,0 +1,74 @@ +import secrets +from pathlib import Path + +from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase # type: ignore +from fastapi_users.db import SQLAlchemyBaseOAuthAccountTable # type: ignore +from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base # type: ignore +from sqlalchemy import Boolean, String, Text, Column # type: ignore +import sqlalchemy # type: ignore +import databases # type: ignore +from fps.config import get_config # type: ignore + +from .config import AuthConfig +from .models import ( + UserDB, +) + +auth_config = get_config(AuthConfig) + +jupyter_dir = Path.home() / ".local" / "share" / "jupyter" +jupyter_dir.mkdir(parents=True, exist_ok=True) +secret_path = jupyter_dir / "jupyverse_secret" +userdb_path = jupyter_dir / "jupyverse_users.db" + +if auth_config.clear_users: + if userdb_path.is_file(): + userdb_path.unlink() + if secret_path.is_file(): + secret_path.unlink() + +if not secret_path.is_file(): + with open(secret_path, "w") as f: + f.write(secrets.token_hex(32)) + +with open(secret_path) as f: + secret = f.read() + + +DATABASE_URL = f"sqlite:///{userdb_path}" + +database = databases.Database(DATABASE_URL) + +Base: DeclarativeMeta = declarative_base() + + +class UserTable(Base, SQLAlchemyBaseUserTable): + initialized = Column(Boolean, default=False, nullable=False) + anonymous = Column(Boolean, default=False, nullable=False) + name = Column(String(length=32), nullable=True) + username = Column(String(length=32), nullable=True) + color = Column(String(length=32), nullable=True) + avatar = Column(String(length=32), nullable=True) + logged_in = Column(Boolean, default=False, nullable=False) + workspace = Column(Text(), nullable=False) + settings = Column(Text(), nullable=False) + + +class OAuthAccount(SQLAlchemyBaseOAuthAccountTable, Base): + pass + + +engine = sqlalchemy.create_engine( + DATABASE_URL, connect_args={"check_same_thread": False} +) + +Base.metadata.create_all(engine) + +users = UserTable.__table__ +oauth_accounts = OAuthAccount.__table__ + +user_db = SQLAlchemyUserDatabase(UserDB, database, users, oauth_accounts) + + +def get_user_db(): + yield user_db diff --git a/plugins/auth/fps_auth/models.py b/plugins/auth/fps_auth/models.py index 96c3e656..36b1c2f1 100644 --- a/plugins/auth/fps_auth/models.py +++ b/plugins/auth/fps_auth/models.py @@ -1,20 +1,7 @@ -import secrets -from pathlib import Path from typing import Optional from pydantic import BaseModel -import databases # type: ignore -import sqlalchemy # type: ignore from fastapi_users import models # type: ignore -from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase # type: ignore -from fastapi_users.db import SQLAlchemyBaseOAuthAccountTable # type: ignore -from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base # type: ignore -from sqlalchemy import Boolean, String, Text, Column -from fps.config import Config # type: ignore - -from .config import AuthConfig - -auth_config = Config(AuthConfig) class JupyterUser(BaseModel): @@ -45,55 +32,3 @@ class UserUpdate(models.BaseUserUpdate, JupyterUser): class UserDB(User, models.BaseUserDB): pass - - -jupyter_dir = Path.home() / ".local" / "share" / "jupyter" -jupyter_dir.mkdir(parents=True, exist_ok=True) -secret_path = jupyter_dir / "jupyverse_secret" -userdb_path = jupyter_dir / "jupyverse_users.db" - -if auth_config.clear_users: - if userdb_path.is_file(): - userdb_path.unlink() - if secret_path.is_file(): - secret_path.unlink() - -if not secret_path.is_file(): - with open(secret_path, "w") as f: - f.write(secrets.token_hex(32)) - -with open(secret_path) as f: - secret = f.read() - -DATABASE_URL = f"sqlite:///{userdb_path}" - -database = databases.Database(DATABASE_URL) - -Base: DeclarativeMeta = declarative_base() - - -class UserTable(Base, SQLAlchemyBaseUserTable): - initialized = Column(Boolean, default=False, nullable=False) - anonymous = Column(Boolean, default=False, nullable=False) - name = Column(String(length=32), nullable=True) - username = Column(String(length=32), nullable=True) - color = Column(String(length=32), nullable=True) - avatar = Column(String(length=32), nullable=True) - logged_in = Column(Boolean, default=False, nullable=False) - workspace = Column(Text(), nullable=False) - settings = Column(Text(), nullable=False) - - -class OAuthAccount(SQLAlchemyBaseOAuthAccountTable, Base): - pass - - -engine = sqlalchemy.create_engine( - DATABASE_URL, connect_args={"check_same_thread": False} -) - -Base.metadata.create_all(engine) - -users = UserTable.__table__ -oauth_accounts = OAuthAccount.__table__ -user_db = SQLAlchemyUserDatabase(UserDB, database, users, oauth_accounts) diff --git a/plugins/auth/fps_auth/routes.py b/plugins/auth/fps_auth/routes.py index b05cfaba..a7b9e2dd 100644 --- a/plugins/auth/fps_auth/routes.py +++ b/plugins/auth/fps_auth/routes.py @@ -1,4 +1,5 @@ from uuid import uuid4 +from typing import Optional import httpx # type: ignore from httpx_oauth.clients.github import GitHubOAuth2 # type: ignore @@ -6,21 +7,17 @@ from fps.config import get_config, FPSConfig # type: ignore from fastapi_users.authentication import CookieAuthentication # type: ignore from fastapi import APIRouter, Depends, HTTPException, status -from fastapi_users import FastAPIUsers # type: ignore +from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore from starlette.requests import Request from sqlalchemy.orm import sessionmaker # type: ignore from .config import get_auth_config +from .db import get_user_db, user_db, secret, database, engine, UserTable from .models import ( - user_db, - engine, - UserTable, User, UserCreate, UserUpdate, UserDB, - database, - secret, ) @@ -31,6 +28,29 @@ auth_config = get_auth_config() +class UserManager(BaseUserManager[UserCreate, UserDB]): + user_db_model = UserDB + + async def on_after_register(self, user: UserDB, request: Optional[Request] = None): + user.initialized = True + for oauth_account in user.oauth_accounts: + print(oauth_account) + if oauth_account.oauth_name == "github": + r = httpx.get( + f"https://api.github.com/user/{oauth_account.account_id}" + ).json() + user.anonymous = False + user.username = r["login"] + user.name = r["name"] + user.color = None + user.avatar = r["avatar_url"] + await self.user_db.update(user) + + +def get_user_manager(user_db=Depends(get_user_db)): + yield UserManager(user_db) + + class LoginCookieAuthentication(CookieAuthentication): async def get_login_response(self, user, response): await super().get_login_response(user, response) @@ -55,7 +75,7 @@ async def get_logout_response(self, user, response): auth_backends = [cookie_authentication] users = FastAPIUsers( - user_db, + get_user_manager, auth_backends, User, UserCreate, @@ -68,29 +88,11 @@ async def get_logout_response(self, user, response): ) -async def on_after_register(user: UserDB, request): - user.initialized = True - await user_db.update(user) - - -async def on_after_github_register(user: UserDB, request: Request): - r = httpx.get( - f"https://api.github.com/user/{user.oauth_accounts[0].account_id}" - ).json() - user.initialized = True - user.anonymous = False - user.username = r["login"] - user.name = r["name"] - user.color = None - user.avatar = r["avatar_url"] - await user_db.update(user) - - github_oauth_router = users.get_oauth_router( - github_oauth_client, secret, after_register=on_after_github_register # type: ignore + github_oauth_client, secret # type: ignore ) auth_router = users.get_auth_router(cookie_authentication) -user_register_router = users.get_register_router(on_after_register) # type: ignore +user_register_router = users.get_register_router() # type: ignore users_router = users.get_users_router() router = APIRouter() diff --git a/plugins/auth/setup.py b/plugins/auth/setup.py index f6ab03a5..2d292ba4 100644 --- a/plugins/auth/setup.py +++ b/plugins/auth/setup.py @@ -6,7 +6,7 @@ packages=find_packages(), install_requires=[ "fps", - "fastapi-users[sqlalchemy]>=7.0.0", + "fastapi-users[sqlalchemy]>=8.0.0", "httpx-oauth", "aiosqlite", ], diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py index 43176d83..bf4035d6 100644 --- a/plugins/kernels/fps_kernels/routes.py +++ b/plugins/kernels/fps_kernels/routes.py @@ -10,7 +10,8 @@ from starlette.requests import Request # type: ignore from fps_auth.routes import cookie_authentication, current_user # type: ignore -from fps_auth.models import User, user_db # type: ignore +from fps_auth.models import User # type: ignore +from fps_auth.db import user_db # type: ignore from fps_auth.config import get_auth_config # type: ignore from .kernel_server.server import KernelServer # type: ignore diff --git a/plugins/terminals/fps_terminals/routes.py b/plugins/terminals/fps_terminals/routes.py index 60ad52bf..17f11244 100644 --- a/plugins/terminals/fps_terminals/routes.py +++ b/plugins/terminals/fps_terminals/routes.py @@ -6,7 +6,8 @@ from fastapi import APIRouter, WebSocket, Response, Depends, status from fps_auth.routes import cookie_authentication, current_user # type: ignore -from fps_auth.models import User, user_db # type: ignore +from fps_auth.models import User # type: ignore +from fps_auth.db import user_db # type: ignore from fps_auth.config import get_auth_config # type: ignore from .models import Terminal diff --git a/plugins/yjs/fps_yjs/routes.py b/plugins/yjs/fps_yjs/routes.py index 95962181..c9ac2737 100644 --- a/plugins/yjs/fps_yjs/routes.py +++ b/plugins/yjs/fps_yjs/routes.py @@ -9,7 +9,7 @@ import fastapi from fps_auth.routes import cookie_authentication # type: ignore -from fps_auth.models import user_db # type: ignore +from fps_auth.db import user_db # type: ignore from fps_auth.config import get_auth_config # type: ignore router = APIRouter()