Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use get_user_db dependency #60

Merged
merged 1 commit into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions plugins/auth/fps_auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,20 @@ def get_user_manager(user_db=Depends(get_user_db)):


class LoginCookieAuthentication(CookieAuthentication):
async def get_login_response(self, user, response):
await super().get_login_response(user, response)
async def get_login_response(self, user, response, user_manager):
await super().get_login_response(user, response, user_manager)
# set user as logged in
user.logged_in = True
await user_db.update(user)
await user_manager.user_db.update(user)
# auto redirect
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = "/lab"

async def get_logout_response(self, user, response):
await super().get_logout_response(user, response)
async def get_logout_response(self, user, response, user_manager):
await super().get_logout_response(user, response, user_manager)
# set user as logged out
user.logged_in = False
await user_db.update(user)
await user_manager.user_db.update(user)


cookie_authentication = LoginCookieAuthentication(
Expand Down Expand Up @@ -153,6 +153,7 @@ def current_user(optional: bool = False):
async def _(
auth_config=Depends(get_auth_config),
user: User = Depends(users.current_user(optional=True)),
user_db=Depends(get_user_db),
):
if auth_config.mode == "noauth":
return await user_db.get_by_email(noauth_email)
Expand Down
2 changes: 1 addition & 1 deletion plugins/auth/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
packages=find_packages(),
install_requires=[
"fps",
"fastapi-users[sqlalchemy,oauth]==8",
"fastapi-users[sqlalchemy,oauth]>=8.1.0",
],
entry_points={
"fps_router": ["fps-auth = fps_auth.routes"],
Expand Down
9 changes: 7 additions & 2 deletions plugins/jupyterlab/fps_jupyterlab/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from starlette.requests import Request # type: ignore
from fps.hooks import register_router # type: ignore

from fps_auth.db import get_user_db # type: ignore
from fps_auth.routes import ( # type: ignore
current_user,
user_db,
cookie_authentication,
LoginCookieAuthentication,
get_user_manager,
)
from fps_auth.models import User # type: ignore
from fps_auth.config import get_auth_config # type: ignore
Expand Down Expand Up @@ -74,13 +75,15 @@ async def get_root(
token: Optional[UUID4] = None,
auth_config=Depends(get_auth_config),
jlab_config=Depends(get_jlab_config),
user_db=Depends(get_user_db),
user_manager=Depends(get_user_manager),
):
if token and auth_config.mode == "token":
user = await user_db.get(token)
if user:
await super(
LoginCookieAuthentication, cookie_authentication
).get_login_response(user, response)
).get_login_response(user, response, user_manager)
# auto redirect
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = jlab_config.base_url + "lab"
Expand Down Expand Up @@ -138,6 +141,7 @@ async def get_workspace_data(user: User = Depends(current_user(optional=True))):
async def set_workspace(
request: Request,
user: User = Depends(current_user()),
user_db=Depends(get_user_db),
):
user.workspace = await request.body()
await user_db.update(user)
Expand Down Expand Up @@ -245,6 +249,7 @@ async def change_setting(
name0,
name1,
user: User = Depends(current_user()),
user_db=Depends(get_user_db),
):
settings = json.loads(user.settings)
settings[f"{name0}:{name1}"] = await request.json()
Expand Down
8 changes: 6 additions & 2 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from fps_auth.routes import cookie_authentication, current_user # type: ignore
from fps_auth.models import User # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.db import get_user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

from .kernel_server.server import KernelServer # type: ignore
Expand Down Expand Up @@ -167,7 +167,11 @@ async def restart_kernel(

@router.websocket("/api/kernels/{kernel_id}/channels")
async def kernel_channels(
websocket: WebSocket, kernel_id, session_id, auth_config=Depends(get_auth_config)
websocket: WebSocket,
kernel_id,
session_id,
auth_config=Depends(get_auth_config),
user_db=Depends(get_user_db),
):
accept_websocket = False
if auth_config.mode == "noauth":
Expand Down
7 changes: 5 additions & 2 deletions plugins/terminals/fps_terminals/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fps_auth.routes import cookie_authentication, current_user # type: ignore
from fps_auth.models import User # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.db import get_user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

from .models import Terminal
Expand Down Expand Up @@ -51,7 +51,10 @@ async def delete_terminal(

@router.websocket("/terminals/websocket/{name}")
async def terminal_websocket(
websocket: WebSocket, name, auth_config=Depends(get_auth_config)
websocket: WebSocket,
name,
auth_config=Depends(get_auth_config),
user_db=Depends(get_user_db),
):
accept_websocket = False
if auth_config.mode == "noauth":
Expand Down
8 changes: 6 additions & 2 deletions plugins/yjs/fps_yjs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import fastapi

from fps_auth.routes import cookie_authentication # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.db import get_user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

router = APIRouter()
Expand All @@ -25,7 +25,11 @@ def get_path_param_names(path: str) -> Set[str]:

@router.websocket("/api/yjs/{type}:{path:path}")
async def websocket_endpoint(
websocket: WebSocket, type, path, auth_config=Depends(get_auth_config)
websocket: WebSocket,
type,
path,
auth_config=Depends(get_auth_config),
user_db=Depends(get_user_db),
):
accept_websocket = False
if auth_config.mode == "noauth":
Expand Down