20
20
from http .cookies import Morsel
21
21
22
22
from tornado import escape , httputil , web
23
- from traitlets import Bool , Dict , Type , Unicode , default
23
+ from traitlets import Bool , Dict , Enum , List , TraitError , Type , Unicode , default , validate
24
24
from traitlets .config import LoggingConfigurable
25
25
26
26
from jupyter_server .transutils import _i18n
31
31
_non_alphanum = re .compile (r"[^A-Za-z0-9]" )
32
32
33
33
34
+ # Define the User properties that can be updated
35
+ UpdatableField = t .Literal ["name" , "display_name" , "initials" , "avatar_url" , "color" ]
36
+
37
+
34
38
@dataclass
35
39
class User :
36
40
"""Object representing a User
@@ -188,6 +192,14 @@ class IdentityProvider(LoggingConfigurable):
188
192
help = _i18n ("The logout handler class to use." ),
189
193
)
190
194
195
+ # Define the fields that can be updated
196
+ updatable_fields = List (
197
+ trait = Enum (list (t .get_args (UpdatableField ))),
198
+ default_value = ["color" ], # Default updatable field
199
+ config = True ,
200
+ help = _i18n ("List of fields in the User model that can be updated." ),
201
+ )
202
+
191
203
token_generated = False
192
204
193
205
@default ("token" )
@@ -207,6 +219,18 @@ def _token_default(self):
207
219
self .token_generated = True
208
220
return binascii .hexlify (os .urandom (24 )).decode ("ascii" )
209
221
222
+ @validate ("updatable_fields" )
223
+ def _validate_updatable_fields (self , proposal ):
224
+ """Validate that all fields in updatable_fields are valid."""
225
+ valid_updatable_fields = list (t .get_args (UpdatableField ))
226
+ invalid_fields = [
227
+ field for field in proposal ["value" ] if field not in valid_updatable_fields
228
+ ]
229
+ if invalid_fields :
230
+ msg = f"Invalid fields in updatable_fields: { invalid_fields } "
231
+ raise TraitError (msg )
232
+ return proposal ["value" ]
233
+
210
234
need_token : bool | Bool [bool , t .Union [bool , int ]] = Bool (True )
211
235
212
236
def get_user (self , handler : web .RequestHandler ) -> User | None | t .Awaitable [User | None ]:
@@ -269,6 +293,31 @@ async def _get_user(self, handler: web.RequestHandler) -> User | None:
269
293
270
294
return user
271
295
296
+ def update_user (
297
+ self , handler : web .RequestHandler , user_data : dict [UpdatableField , str ]
298
+ ) -> User :
299
+ """Update user information and persist the user model."""
300
+ self .check_update (user_data )
301
+ current_user = t .cast (User , handler .current_user )
302
+ updated_user = self .update_user_model (current_user , user_data )
303
+ self .persist_user_model (handler )
304
+ return updated_user
305
+
306
+ def check_update (self , user_data : dict [UpdatableField , str ]) -> None :
307
+ """Raises if some fields to update are not updatable."""
308
+ for field in user_data :
309
+ if field not in self .updatable_fields :
310
+ msg = f"Field { field } is not updatable"
311
+ raise ValueError (msg )
312
+
313
+ def update_user_model (self , current_user : User , user_data : dict [UpdatableField , str ]) -> User :
314
+ """Update user information."""
315
+ raise NotImplementedError
316
+
317
+ def persist_user_model (self , handler : web .RequestHandler ) -> None :
318
+ """Persist the user model (i.e. a cookie)."""
319
+ raise NotImplementedError
320
+
272
321
def identity_model (self , user : User ) -> dict [str , t .Any ]:
273
322
"""Return a User as an Identity model"""
274
323
# TODO: validate?
@@ -617,6 +666,16 @@ class PasswordIdentityProvider(IdentityProvider):
617
666
def _need_token_default (self ):
618
667
return not bool (self .hashed_password )
619
668
669
+ @default ("updatable_fields" )
670
+ def _default_updatable_fields (self ):
671
+ return [
672
+ "name" ,
673
+ "display_name" ,
674
+ "initials" ,
675
+ "avatar_url" ,
676
+ "color" ,
677
+ ]
678
+
620
679
@property
621
680
def login_available (self ) -> bool :
622
681
"""Whether a LoginHandler is needed - and therefore whether the login page should be displayed."""
@@ -627,6 +686,17 @@ def auth_enabled(self) -> bool:
627
686
"""Return whether any auth is enabled"""
628
687
return bool (self .hashed_password or self .token )
629
688
689
+ def update_user_model (self , current_user : User , user_data : dict [UpdatableField , str ]) -> User :
690
+ """Update user information."""
691
+ for field in self .updatable_fields :
692
+ if field in user_data :
693
+ setattr (current_user , field , user_data [field ])
694
+ return current_user
695
+
696
+ def persist_user_model (self , handler : web .RequestHandler ) -> None :
697
+ """Persist the user model to a cookie."""
698
+ self .set_login_cookie (handler , handler .current_user )
699
+
630
700
def passwd_check (self , password ):
631
701
"""Check password against our stored hashed password"""
632
702
return passwd_check (self .hashed_password , password )
0 commit comments