11import uuid
2- from typing import Generic , Optional
2+ from typing import Any , Dict , Generic , List , Optional , Tuple
33
44import httpx
55from fastapi import Depends , HTTPException , Response , WebSocket , status
1919from fastapi_users .db import SQLAlchemyUserDatabase
2020from fps .exceptions import RedirectException # type: ignore
2121from fps .logging import get_configured_logger # type: ignore
22+ from fps_lab .config import get_lab_config # type: ignore
2223from httpx_oauth .clients .github import GitHubOAuth2 # type: ignore
2324from starlette .requests import Request
2425
2526from .config import get_auth_config
2627from .db import User , get_user_db , secret
27- from .models import UserCreate
28+ from .models import UserCreate , UserRead
2829
2930logger = get_configured_logger ("auth" )
3031
@@ -106,8 +107,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
106107 yield UserManager (user_db )
107108
108109
109- async def get_enabled_backends (auth_config = Depends (get_auth_config )):
110- if auth_config .mode == "noauth" and not auth_config .collaborative :
110+ async def get_enabled_backends (
111+ auth_config = Depends (get_auth_config ), lab_config = Depends (get_lab_config )
112+ ):
113+ if auth_config .mode == "noauth" and not lab_config .collaborative :
111114 return [noauth_authentication , github_cookie_authentication ]
112115 else :
113116 return [cookie_authentication , github_cookie_authentication ]
@@ -131,35 +134,33 @@ async def create_guest(user_manager, auth_config):
131134 password = "" ,
132135 workspace = global_user .workspace ,
133136 settings = global_user .settings ,
137+ permissions = {},
134138 )
135139 return await user_manager .create (UserCreate (** guest ))
136140
137141
138- def current_user (resource : Optional [str ] = None ):
142+ def current_user (permissions : Optional [Dict [ str , List [ str ]] ] = None ):
139143 async def _ (
140- request : Request ,
141144 response : Response ,
142145 token : Optional [str ] = None ,
143146 user : Optional [User ] = Depends (
144147 fapi_users .current_user (optional = True , get_enabled_backends = get_enabled_backends )
145148 ),
146149 user_manager : UserManager = Depends (get_user_manager ),
147150 auth_config = Depends (get_auth_config ),
151+ lab_config = Depends (get_lab_config ),
148152 ):
149153 if auth_config .mode == "user" :
150154 # "user" authentication: check authorization
151- if user and resource :
152- # check if allowed to access the resource
153- permissions = user .permissions .get (resource , [])
154- if request .method in ("GET" , "HEAD" ):
155- if "read" not in permissions :
156- user = None
157- elif request .method in ("POST" , "PUT" , "PATCH" , "DELETE" ):
158- if "write" not in permissions :
155+ if user and permissions :
156+ for resource , actions in permissions .items ():
157+ user_actions_for_resource = user .permissions .get (resource , [])
158+ if not all ([a in user_actions_for_resource for a in actions ]):
159159 user = None
160+ break
160161 else :
161162 # "noauth" or "token" authentication
162- if auth_config .collaborative :
163+ if lab_config .collaborative :
163164 if not user and auth_config .mode == "noauth" :
164165 user = await create_guest (user_manager , auth_config )
165166 await cookie_authentication .login (get_jwt_strategy (), user , response )
@@ -188,28 +189,60 @@ async def _(
188189 return _
189190
190191
191- def websocket_for_current_user (resource : str ):
192+ def websocket_auth (permissions : Optional [Dict [str , List [str ]]] = None ):
193+ """
194+ A function returning a dependency for the WebSocket connection.
195+
196+ :param permissions: the permissions the user should be granted access to. The user should have
197+ access to at least one of them for the WebSocket to be opened.
198+ :returns: a dependency for the WebSocket connection. The dependency returns a tuple consisting
199+ of the websocket and the checked user permissions if the websocket is accepted, None otherwise.
200+ """
201+
192202 async def _ (
193203 websocket : WebSocket ,
194204 auth_config = Depends (get_auth_config ),
195205 user_manager : UserManager = Depends (get_user_manager ),
196- ) -> Optional [WebSocket ]:
206+ ) -> Optional [Tuple [ WebSocket , Optional [ Dict [ str , List [ str ]]]] ]:
197207 accept_websocket = False
208+ checked_permissions : Optional [Dict [str , List [str ]]] = None
198209 if auth_config .mode == "noauth" :
199210 accept_websocket = True
200211 elif "fastapiusersauth" in websocket ._cookies :
201212 token = websocket ._cookies ["fastapiusersauth" ]
202213 user = await get_jwt_strategy ().read_token (token , user_manager )
203214 if user :
204215 if auth_config .mode == "user" :
205- if "execute" in user .permissions .get (resource , []):
216+ # "user" authentication: check authorization
217+ if permissions is None :
206218 accept_websocket = True
219+ else :
220+ checked_permissions = {}
221+ for resource , actions in permissions .items ():
222+ user_actions_for_resource = user .permissions .get (resource )
223+ if user_actions_for_resource is None :
224+ continue
225+ allowed = checked_permissions [resource ] = []
226+ for action in actions :
227+ if action in user_actions_for_resource :
228+ allowed .append (action )
229+ accept_websocket = True
207230 else :
208231 accept_websocket = True
209232 if accept_websocket :
210- return websocket
233+ return websocket , checked_permissions
211234 else :
212235 await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
213236 return None
214237
215238 return _
239+
240+
241+ async def update_user (
242+ user : UserRead = Depends (current_user ()), user_db : SQLAlchemyUserDatabase = Depends (get_user_db )
243+ ):
244+ async def _ (data : Dict [str , Any ]) -> UserRead :
245+ await user_db .update (user , data )
246+ return user
247+
248+ return _
0 commit comments