Skip to content

Commit

Permalink
Update with fastapi-users v9
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 3, 2022
1 parent 5b446ad commit 696b3f3
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 38 deletions.
73 changes: 52 additions & 21 deletions plugins/auth/fps_auth/backends.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from uuid import uuid4
from typing import Optional
from typing import Optional, Generic

from fps.exceptions import RedirectException # type: ignore

import httpx
from httpx_oauth.clients.github import GitHubOAuth2 # type: ignore
from fastapi import Depends, Response, HTTPException, status

from fastapi_users.authentication import CookieAuthentication, BaseAuthentication # type: ignore
from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore
from fastapi_users.authentication import (
AuthenticationBackend,
CookieTransport,
JWTStrategy,
)
from fastapi_users.authentication.transport.base import Transport
from fastapi_users.authentication.strategy.base import Strategy
from fastapi_users import FastAPIUsers, BaseUserManager, models # type: ignore
from starlette.requests import Request

from fps.logging import get_configured_logger # type: ignore
Expand All @@ -20,30 +26,55 @@
logger = get_configured_logger("auth")


class NoAuthAuthentication(BaseAuthentication):
def __init__(self, name: str = "noauth"):
super().__init__(name, logout=False)
self.scheme = None # type: ignore
class NoAuthTransport(Transport):
scheme = None # type: ignore


async def __call__(self, credentials, user_manager):
class NoAuthStrategy(Strategy, Generic[models.UC, models.UD]):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD]
) -> Optional[models.UD]:
active_user = await user_manager.user_db.get_by_email(
get_auth_config().global_email
)
return active_user


class GitHubAuthentication(CookieAuthentication):
async def get_login_response(self, user, response, user_manager):
await super().get_login_response(user, response, user_manager)
class GitHubTransport(CookieTransport):
async def get_login_response(self, token: str, response: Response):
await super().get_login_response(token, response)
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = "/lab"


noauth_authentication = NoAuthAuthentication(name="noauth")
cookie_authentication = CookieAuthentication(
secret=secret, cookie_secure=get_auth_config().cookie_secure, name="cookie" # type: ignore
noauth_transport = NoAuthTransport()
cookie_transport = CookieTransport()
github_transport = GitHubTransport()


def get_noauth_strategy() -> NoAuthStrategy:
return NoAuthStrategy()


def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(secret=secret, lifetime_seconds=None)


noauth_authentication = AuthenticationBackend(
name="noauth",
transport=noauth_transport,
get_strategy=get_noauth_strategy,
)
cookie_authentication = AuthenticationBackend(
name="cookie",
transport=cookie_transport,
get_strategy=get_jwt_strategy,
)
github_cookie_authentication = AuthenticationBackend(
name="github",
transport=github_transport,
get_strategy=get_jwt_strategy,
)
github_cookie_authentication = GitHubAuthentication(secret=secret, name="github")
github_authentication = GitHubOAuth2(
get_auth_config().client_id, get_auth_config().client_secret.get_secret_value()
)
Expand Down Expand Up @@ -127,24 +158,24 @@ async def current_user(
if auth_config.collaborative:
if not active_user and auth_config.mode == "noauth":
active_user = await create_guest(user_db, auth_config)
await cookie_authentication.get_login_response(
active_user, response, user_manager
await noauth_authentication.login(
get_noauth_strategy(), active_user, response
)

elif not active_user and auth_config.mode == "token":
global_user = await user_db.get_by_email(auth_config.global_email)
if global_user and global_user.hashed_password == token:
active_user = await create_guest(user_db, auth_config)
await cookie_authentication.get_login_response(
active_user, response, user_manager
await cookie_authentication.login(
get_jwt_strategy(), active_user, response
)
else:
if auth_config.mode == "token":
global_user = await user_db.get_by_email(auth_config.global_email)
if global_user and global_user.hashed_password == token:
active_user = global_user
await cookie_authentication.get_login_response(
active_user, response, user_manager
await cookie_authentication.login(
get_jwt_strategy(), active_user, response
)

if active_user:
Expand Down
2 changes: 1 addition & 1 deletion plugins/auth/fps_auth/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import pytest
import pytest # type: ignore
from fps_auth.config import AuthConfig, get_auth_config


Expand Down
6 changes: 5 additions & 1 deletion plugins/auth/fps_auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
fapi_users,
current_user,
cookie_authentication,
github_cookie_authentication,
github_authentication,
)
from .models import User, UserDB
Expand Down Expand Up @@ -81,7 +82,10 @@ async def get_users(user: User = Depends(current_user)):

# GitHub OAuth register router
r_github = register_router(
fapi_users.get_oauth_router(github_authentication, secret), prefix="/auth/github"
fapi_users.get_oauth_router(
github_authentication, github_cookie_authentication, secret
),
prefix="/auth/github",
)

r = register_router(router)
2 changes: 1 addition & 1 deletion plugins/auth/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ packages = find:
install_requires =
fps >=0.0.8
aiosqlite
fastapi-users[sqlalchemy,oauth] >=8.1.0,<9
fastapi-users[sqlalchemy,oauth] >=9.1.1,<10

[options.entry_points]
fps_router =
Expand Down
4 changes: 2 additions & 2 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi.responses import FileResponse
from starlette.requests import Request # type: ignore

from fps_auth.backends import cookie_authentication, current_user # type: ignore
from fps_auth.backends import get_jwt_strategy, current_user # type: ignore
from fps_auth.models import User # type: ignore
from fps_auth.db import get_user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore
Expand Down Expand Up @@ -194,7 +194,7 @@ async def kernel_channels(
accept_websocket = True
else:
cookie = websocket._cookies["fastapiusersauth"]
user = await cookie_authentication(cookie, user_db)
user = await get_jwt_strategy().read_token(cookie, user_db)
if user:
accept_websocket = True
if accept_websocket:
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def authenticated_user(client):
assert response.status_code == 201
# login with registered user
login_body = {"username": username + "@example.com", "password": username}
assert "fastapiusersauth" not in client.cookies
assert "fastapiusersauth" not in client.cookies.keys()
response = client.post("/auth/login", data=login_body)
assert "fastapiusersauth" in client.cookies
assert "fastapiusersauth" in client.cookies.keys()
# who am I?
response = client.get("/auth/user/me")
response = client.get("/auth/user/me", cookies=client.cookies.get_dict())
assert response.status_code != 401
return username

Expand Down
9 changes: 5 additions & 4 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_kernel_channels_unauthenticated(client):
def test_kernel_channels_authenticated(client, authenticated_user):
with client.websocket_connect(
"/api/kernels/kernel_id_0/channels?session_id=session_id_0",
cookies=client.cookies,
cookies=client.cookies.get_dict(),
):
pass

Expand All @@ -39,15 +39,16 @@ def test_root_auth(auth_mode, client, app):


@pytest.mark.parametrize("auth_mode", ("noauth",))
def test_no_auth(auth_mode, client, app):
def test_no_auth(client, app):
with TestClient(app) as client:
response = client.get("/lab/api/settings")
assert response.status_code == 200


@pytest.mark.parametrize("auth_mode", ("token",))
def test_token_auth(auth_mode, client, app):
def test_token_auth(client, app):
auth_config = get_auth_config()
with TestClient(app) as client:
auth_config = get_auth_config()
response = client.get(f"/?token={auth_config.token}")
response = client.get("/", cookies=client.cookies.get_dict())
assert response.status_code == 200
8 changes: 6 additions & 2 deletions tests/test_contents.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
from pathlib import Path

import pytest
from utils import clear_content_values, create_content, sort_content_by_name
from fastapi.testclient import TestClient


def test_tree(client, authenticated_user, tmp_path):
@pytest.mark.parametrize("auth_mode", ("noauth",))
def test_tree(client, app, tmp_path):
os.chdir(tmp_path)
dname = Path(".")
expected = []
Expand Down Expand Up @@ -48,7 +51,8 @@ def test_tree(client, authenticated_user, tmp_path):
path=str(dname),
format="json",
)
response = client.get("/api/contents", params={"content": 1})
with TestClient(app) as client:
response = client.get("/api/contents", params={"content": 1})
actual = response.json()
# ignore modification and creation times
clear_content_values(actual, keys=["created", "last_modified"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.mark.asyncio
@pytest.mark.parametrize("auth_mode", ("noauth",))
async def test_kernel_messages(auth_mode, client, capfd):
async def test_kernel_messages(client, capfd):
kernel_id = "kernel_id_0"
kernel_name = "python3"
kernelspec_path = (
Expand Down
4 changes: 2 additions & 2 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@pytest.mark.parametrize("auth_mode", ("noauth",))
def test_put_settings(auth_mode, client, app):
def test_put_settings(client, app):
with TestClient(app) as client:
response = client.put(
"/lab/api/settings/@jupyterlab/apputils-extension:themes",
Expand All @@ -15,7 +15,7 @@ def test_put_settings(auth_mode, client, app):


@pytest.mark.parametrize("auth_mode", ("noauth",))
def test_get_settings(auth_mode, client, app):
def test_get_settings(client, app):
with TestClient(app) as client:
response = client.get("/lab/api/settings/@jupyterlab/apputils-extension:themes")
assert response.status_code == 200
Expand Down

0 comments on commit 696b3f3

Please sign in to comment.