diff --git a/plugins/auth/fps_auth/backends.py b/plugins/auth/fps_auth/backends.py index d5d1d271..1cb9d55f 100644 --- a/plugins/auth/fps_auth/backends.py +++ b/plugins/auth/fps_auth/backends.py @@ -1,5 +1,5 @@ from uuid import uuid4 -from typing import Optional +from typing import Optional, Generic from fps.exceptions import RedirectException # type: ignore @@ -7,8 +7,14 @@ from httpx_oauth.clients.github import GitHubOAuth2 # type: ignore from fastapi import Depends, Response, HTTPException, status -from fastapi_users.authentication import CookieAuthentication, BaseAuthentication # type: ignore -from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore +from fastapi_users.authentication import ( + AuthenticationBackend, + CookieTransport, + JWTStrategy, +) +from fastapi_users.authentication.transport.base import Transport +from fastapi_users.authentication.strategy.base import Strategy +from fastapi_users import FastAPIUsers, BaseUserManager, models # type: ignore from starlette.requests import Request from fps.logging import get_configured_logger # type: ignore @@ -20,30 +26,55 @@ logger = get_configured_logger("auth") -class NoAuthAuthentication(BaseAuthentication): - def __init__(self, name: str = "noauth"): - super().__init__(name, logout=False) - self.scheme = None # type: ignore +class NoAuthTransport(Transport): + scheme = None # type: ignore + - async def __call__(self, credentials, user_manager): +class NoAuthStrategy(Strategy, Generic[models.UC, models.UD]): + async def read_token( + self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] + ) -> Optional[models.UD]: active_user = await user_manager.user_db.get_by_email( get_auth_config().global_email ) return active_user -class GitHubAuthentication(CookieAuthentication): - async def get_login_response(self, user, response, user_manager): - await super().get_login_response(user, response, user_manager) +class GitHubTransport(CookieTransport): + async def get_login_response(self, token: str, response: Response): + await super().get_login_response(token, response) response.status_code = status.HTTP_302_FOUND response.headers["Location"] = "/lab" -noauth_authentication = NoAuthAuthentication(name="noauth") -cookie_authentication = CookieAuthentication( - secret=secret, cookie_secure=get_auth_config().cookie_secure, name="cookie" # type: ignore +noauth_transport = NoAuthTransport() +cookie_transport = CookieTransport() +github_transport = GitHubTransport() + + +def get_noauth_strategy() -> NoAuthStrategy: + return NoAuthStrategy() + + +def get_jwt_strategy() -> JWTStrategy: + return JWTStrategy(secret=secret, lifetime_seconds=None) + + +noauth_authentication = AuthenticationBackend( + name="noauth", + transport=noauth_transport, + get_strategy=get_noauth_strategy, +) +cookie_authentication = AuthenticationBackend( + name="cookie", + transport=cookie_transport, + get_strategy=get_jwt_strategy, +) +github_cookie_authentication = AuthenticationBackend( + name="github", + transport=github_transport, + get_strategy=get_jwt_strategy, ) -github_cookie_authentication = GitHubAuthentication(secret=secret, name="github") github_authentication = GitHubOAuth2( get_auth_config().client_id, get_auth_config().client_secret.get_secret_value() ) @@ -127,24 +158,24 @@ async def current_user( if auth_config.collaborative: if not active_user and auth_config.mode == "noauth": active_user = await create_guest(user_db, auth_config) - await cookie_authentication.get_login_response( - active_user, response, user_manager + await noauth_authentication.login( + get_noauth_strategy(), active_user, response ) elif not active_user and auth_config.mode == "token": global_user = await user_db.get_by_email(auth_config.global_email) if global_user and global_user.hashed_password == token: active_user = await create_guest(user_db, auth_config) - await cookie_authentication.get_login_response( - active_user, response, user_manager + await cookie_authentication.login( + get_jwt_strategy(), active_user, response ) else: if auth_config.mode == "token": global_user = await user_db.get_by_email(auth_config.global_email) if global_user and global_user.hashed_password == token: active_user = global_user - await cookie_authentication.get_login_response( - active_user, response, user_manager + await cookie_authentication.login( + get_jwt_strategy(), active_user, response ) if active_user: diff --git a/plugins/auth/fps_auth/fixtures.py b/plugins/auth/fps_auth/fixtures.py index 4646f3de..279d2aa0 100644 --- a/plugins/auth/fps_auth/fixtures.py +++ b/plugins/auth/fps_auth/fixtures.py @@ -1,4 +1,4 @@ -import pytest +import pytest # type: ignore from fps_auth.config import AuthConfig, get_auth_config diff --git a/plugins/auth/fps_auth/routes.py b/plugins/auth/fps_auth/routes.py index 19f7f792..04607214 100644 --- a/plugins/auth/fps_auth/routes.py +++ b/plugins/auth/fps_auth/routes.py @@ -14,6 +14,7 @@ fapi_users, current_user, cookie_authentication, + github_cookie_authentication, github_authentication, ) from .models import User, UserDB @@ -81,7 +82,10 @@ async def get_users(user: User = Depends(current_user)): # GitHub OAuth register router r_github = register_router( - fapi_users.get_oauth_router(github_authentication, secret), prefix="/auth/github" + fapi_users.get_oauth_router( + github_authentication, github_cookie_authentication, secret + ), + prefix="/auth/github", ) r = register_router(router) diff --git a/plugins/auth/setup.cfg b/plugins/auth/setup.cfg index 4174f162..91c970cb 100644 --- a/plugins/auth/setup.cfg +++ b/plugins/auth/setup.cfg @@ -9,7 +9,7 @@ packages = find: install_requires = fps >=0.0.8 aiosqlite - fastapi-users[sqlalchemy,oauth] >=8.1.0,<9 + fastapi-users[sqlalchemy,oauth] >=9.1.1,<10 [options.entry_points] fps_router = diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py index 53e2ad66..4c9e66a3 100644 --- a/plugins/kernels/fps_kernels/routes.py +++ b/plugins/kernels/fps_kernels/routes.py @@ -9,7 +9,7 @@ from fastapi.responses import FileResponse from starlette.requests import Request # type: ignore -from fps_auth.backends import cookie_authentication, current_user # type: ignore +from fps_auth.backends import get_jwt_strategy, current_user # type: ignore from fps_auth.models import User # type: ignore from fps_auth.db import get_user_db # type: ignore from fps_auth.config import get_auth_config # type: ignore @@ -194,7 +194,7 @@ async def kernel_channels( accept_websocket = True else: cookie = websocket._cookies["fastapiusersauth"] - user = await cookie_authentication(cookie, user_db) + user = await get_jwt_strategy().read_token(cookie, user_db) if user: accept_websocket = True if accept_websocket: diff --git a/tests/conftest.py b/tests/conftest.py index 0b5e3257..6c08ded5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,11 +39,11 @@ def authenticated_user(client): assert response.status_code == 201 # login with registered user login_body = {"username": username + "@example.com", "password": username} - assert "fastapiusersauth" not in client.cookies + assert "fastapiusersauth" not in client.cookies.keys() response = client.post("/auth/login", data=login_body) - assert "fastapiusersauth" in client.cookies + assert "fastapiusersauth" in client.cookies.keys() # who am I? - response = client.get("/auth/user/me") + response = client.get("/auth/user/me", cookies=client.cookies.get_dict()) assert response.status_code != 401 return username diff --git a/tests/test_auth.py b/tests/test_auth.py index 3946720c..3339e962 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -19,7 +19,7 @@ def test_kernel_channels_unauthenticated(client): def test_kernel_channels_authenticated(client, authenticated_user): with client.websocket_connect( "/api/kernels/kernel_id_0/channels?session_id=session_id_0", - cookies=client.cookies, + cookies=client.cookies.get_dict(), ): pass @@ -39,15 +39,16 @@ def test_root_auth(auth_mode, client, app): @pytest.mark.parametrize("auth_mode", ("noauth",)) -def test_no_auth(auth_mode, client, app): +def test_no_auth(client, app): with TestClient(app) as client: response = client.get("/lab/api/settings") assert response.status_code == 200 @pytest.mark.parametrize("auth_mode", ("token",)) -def test_token_auth(auth_mode, client, app): +def test_token_auth(client, app): + auth_config = get_auth_config() with TestClient(app) as client: - auth_config = get_auth_config() response = client.get(f"/?token={auth_config.token}") + response = client.get("/", cookies=client.cookies.get_dict()) assert response.status_code == 200 diff --git a/tests/test_contents.py b/tests/test_contents.py index 7cc44c22..bea76df3 100644 --- a/tests/test_contents.py +++ b/tests/test_contents.py @@ -1,10 +1,13 @@ import os from pathlib import Path +import pytest from utils import clear_content_values, create_content, sort_content_by_name +from fastapi.testclient import TestClient -def test_tree(client, authenticated_user, tmp_path): +@pytest.mark.parametrize("auth_mode", ("noauth",)) +def test_tree(client, app, tmp_path): os.chdir(tmp_path) dname = Path(".") expected = [] @@ -48,7 +51,8 @@ def test_tree(client, authenticated_user, tmp_path): path=str(dname), format="json", ) - response = client.get("/api/contents", params={"content": 1}) + with TestClient(app) as client: + response = client.get("/api/contents", params={"content": 1}) actual = response.json() # ignore modification and creation times clear_content_values(actual, keys=["created", "last_modified"]) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index ab6f998e..93782429 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -7,7 +7,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("auth_mode", ("noauth",)) -async def test_kernel_messages(auth_mode, client, capfd): +async def test_kernel_messages(client, capfd): kernel_id = "kernel_id_0" kernel_name = "python3" kernelspec_path = ( diff --git a/tests/test_settings.py b/tests/test_settings.py index cc48f892..48721a37 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize("auth_mode", ("noauth",)) -def test_put_settings(auth_mode, client, app): +def test_put_settings(client, app): with TestClient(app) as client: response = client.put( "/lab/api/settings/@jupyterlab/apputils-extension:themes", @@ -15,7 +15,7 @@ def test_put_settings(auth_mode, client, app): @pytest.mark.parametrize("auth_mode", ("noauth",)) -def test_get_settings(auth_mode, client, app): +def test_get_settings(client, app): with TestClient(app) as client: response = client.get("/lab/api/settings/@jupyterlab/apputils-extension:themes") assert response.status_code == 200