Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update fastapi-users>=10 #179

Merged
merged 1 commit into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:

- name: Install jupyverse
run: |
mkdir fps && cd fps && curl -L -O https://github.com/jupyter-server/fps/archive/master.tar.gz && tar zxf master.tar.gz && cd fps-master && pip install . && pip install ./plugins/uvicorn && cd ../.. && rm -rf fps
pip install fps[uvicorn]
pip install . --no-deps
pip install ./plugins/auth
pip install ./plugins/contents
Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ When switching e.g. from the JupyterLab to the RetroLab front-end, you need to
Clone this repository and install the needed plugins:

```bash
pip install fps[uvicorn]
pip install -e . --no-deps
pip install -e plugins/jupyterlab
pip install -e plugins/login
Expand All @@ -51,9 +52,6 @@ pip install -e plugins/lab
pip install -e plugins/nbconvert
pip install -e plugins/yjs

# you should also install the latest FPS:
pip install git+https://github.com/jupyter-server/fps

# if you want RetroLab instead of JupyterLab:
# pip install -e . --no-deps
# pip install -e plugins/retrolab
Expand Down
62 changes: 31 additions & 31 deletions plugins/auth/fps_auth/backends.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
import uuid
from typing import Generic, Optional
from uuid import uuid4

import httpx
from fastapi import Depends, HTTPException, Response, status
from fastapi_users import BaseUserManager, FastAPIUsers, models # type: ignore
from fastapi_users import ( # type: ignore
BaseUserManager,
FastAPIUsers,
UUIDIDMixin,
models,
)
from fastapi_users.authentication import (
AuthenticationBackend,
CookieTransport,
JWTStrategy,
)
from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users.authentication.transport.base import Transport
from fastapi_users.db import SQLAlchemyUserDatabase
from fps.exceptions import RedirectException # type: ignore
from fps.logging import get_configured_logger # type: ignore
from httpx_oauth.clients.github import GitHubOAuth2 # type: ignore
from starlette.requests import Request

from .config import get_auth_config
from .db import get_user_db, secret
from .models import User, UserCreate, UserDB, UserUpdate
from .db import User, get_user_db, secret

logger = get_configured_logger("auth")

Expand All @@ -27,10 +32,10 @@ class NoAuthTransport(Transport):
scheme = None # type: ignore


class NoAuthStrategy(Strategy, Generic[models.UC, models.UD]):
class NoAuthStrategy(Strategy, Generic[models.UP, models.ID]):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD]
) -> Optional[models.UD]:
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
active_user = await user_manager.user_db.get_by_email(get_auth_config().global_email)
return active_user

Expand Down Expand Up @@ -75,29 +80,28 @@ def get_jwt_strategy() -> JWTStrategy:
)


class UserManager(BaseUserManager[UserCreate, UserDB]):
user_db_model = UserDB

async def on_after_register(self, user: UserDB, request: Optional[Request] = None):
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_register(self, user: User, request: Optional[Request] = None):
for oauth_account in user.oauth_accounts:
if oauth_account.oauth_name == "github":
async with httpx.AsyncClient() as client:
r = (
await client.get(f"https://api.github.com/user/{oauth_account.account_id}")
).json()

user.anonymous = False
user.username = r["login"]
user.name = r["name"]
user.color = None
user.avatar = r["avatar_url"]
user.workspace = "{}"
user.settings = "{}"

await self.user_db.update(user)
await self.user_db.update(
user,
dict(
anonymous=False,
username=r["login"],
color=None,
avatar=r["avatar_url"],
is_active=True,
),
)


def get_user_manager(user_db=Depends(get_user_db)):
def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
yield UserManager(user_db)


Expand All @@ -108,20 +112,18 @@ async def get_enabled_backends(auth_config=Depends(get_auth_config)):
return [cookie_authentication, github_cookie_authentication]


fapi_users = FastAPIUsers(
fapi_users = FastAPIUsers[User, uuid.UUID](
get_user_manager,
[noauth_authentication, cookie_authentication, github_cookie_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)


async def create_guest(user_db, auth_config):
# workspace and settings are copied from global user
# but this is a new user
global_user = await user_db.get_by_email(auth_config.global_email)
user_id = str(uuid4())
guest = UserDB(
user_id = str(uuid.uuid4())
guest = dict(
id=user_id,
anonymous=True,
email=f"{user_id}@jupyter.com",
Expand All @@ -130,8 +132,7 @@ async def create_guest(user_db, auth_config):
workspace=global_user.workspace,
settings=global_user.settings,
)
await user_db.create(guest)
return guest
return await user_db.create(guest)


async def current_user(
Expand All @@ -141,7 +142,6 @@ async def current_user(
fapi_users.current_user(optional=True, get_enabled_backends=get_enabled_backends)
),
user_db=Depends(get_user_db),
user_manager: UserManager = Depends(get_user_manager),
auth_config=Depends(get_auth_config),
):
active_user = user
Expand Down
53 changes: 33 additions & 20 deletions plugins/auth/fps_auth/db.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import secrets
from pathlib import Path
from typing import AsyncGenerator, List

import databases # type: ignore
import sqlalchemy # type: ignore
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTable # type: ignore
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID # type: ignore
from fastapi_users.db import ( # type: ignore
SQLAlchemyBaseUserTable,
SQLAlchemyBaseUserTableUUID,
SQLAlchemyUserDatabase,
)
from fps.config import get_config # type: ignore
from sqlalchemy import Boolean, Column, String, Text # type: ignore
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine # type: ignore
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base # type: ignore
from sqlalchemy.orm import relationship, sessionmaker # type: ignore

from .config import AuthConfig
from .models import UserDB

auth_config = get_config(AuthConfig)

Expand All @@ -36,37 +37,49 @@
secret = f.read()


DATABASE_URL = f"sqlite:///{userdb_path}"
DATABASE_URL = f"sqlite+aiosqlite:///{userdb_path}"
Base: DeclarativeMeta = declarative_base()

database = databases.Database(DATABASE_URL)

Base: DeclarativeMeta = declarative_base()
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
pass


class UserTable(Base, SQLAlchemyBaseUserTable):
class User(SQLAlchemyBaseUserTableUUID, Base):
anonymous = Column(Boolean, default=True, nullable=False)
email = Column(String(length=32), nullable=False, unique=True)
username = Column(String(length=32), nullable=True, unique=True)
name = Column(String(length=32), nullable=True)
color = Column(String(length=32), nullable=True)
avatar = Column(String(length=32), nullable=True)
workspace = Column(Text(), nullable=False)
settings = Column(Text(), nullable=False)
workspace = Column(Text(), default="{}", nullable=False)
settings = Column(Text(), default="{}", nullable=False)
oauth_accounts: List[OAuthAccount] = relationship("OAuthAccount", lazy="joined")


class OAuthAccount(SQLAlchemyBaseOAuthAccountTable, Base):
pass
engine = create_async_engine(DATABASE_URL)
Session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)


async def create_db_and_tables():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)


engine = sqlalchemy.create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with Session() as session:
yield session

Base.metadata.create_all(engine)

users = UserTable.__table__
oauth_accounts = OAuthAccount.__table__
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)

user_db = SQLAlchemyUserDatabase(UserDB, database, users, oauth_accounts)

class UserDb:
async def __aenter__(self):
self.session = Session()
session = await self.session.__aenter__()
return SQLAlchemyUserDatabase(session, User, OAuthAccount)

def get_user_db():
yield user_db
async def __aexit__(self, exc_type, exc_value, exc_tb):
return await self.session.__aexit__(exc_type, exc_value, exc_tb)
13 changes: 5 additions & 8 deletions plugins/auth/fps_auth/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
from typing import Optional

from fastapi_users import models # type: ignore
from fastapi_users import schemas
from pydantic import BaseModel


Expand All @@ -14,20 +15,16 @@ class JupyterUser(BaseModel):
settings: str = "{}"


class User(models.BaseUser, models.BaseOAuthAccountMixin, JupyterUser):
class UserRead(schemas.BaseUser[uuid.UUID], JupyterUser):
pass


class UserCreate(models.BaseUserCreate):
class UserCreate(schemas.BaseUserCreate):
anonymous: bool = True
username: Optional[str] = None
name: Optional[str] = None
color: Optional[str] = None


class UserUpdate(models.BaseUserUpdate, JupyterUser):
pass


class UserDB(User, models.BaseUserDB):
class UserUpdate(schemas.BaseUserUpdate, JupyterUser):
pass
Loading