2121from typing_extensions import NoReturn , Protocol
2222
2323from twisted .web .http import Request
24+ from twisted .web .iweb import IRequest
2425
2526from synapse .api .constants import LoginType
2627from synapse .api .errors import Codes , NotFoundError , RedirectException , SynapseError
2728from synapse .handlers .ui_auth import UIAuthSessionDataConstants
2829from synapse .http import get_request_user_agent
29- from synapse .http .server import respond_with_html
30+ from synapse .http .server import respond_with_html , respond_with_redirect
3031from synapse .http .site import SynapseRequest
3132from synapse .types import JsonDict , UserID , contains_invalid_mxid_characters
3233from synapse .util .async_helpers import Linearizer
@@ -141,6 +142,9 @@ class UsernameMappingSession:
141142 # expiry time for the session, in milliseconds
142143 expiry_time_ms = attr .ib (type = int )
143144
145+ # choices made by the user
146+ chosen_localpart = attr .ib (type = Optional [str ], default = None )
147+
144148
145149# the HTTP cookie used to track the mapping session id
146150USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@@ -647,6 +651,25 @@ async def complete_sso_ui_auth_request(
647651 )
648652 respond_with_html (request , 200 , html )
649653
654+ def get_mapping_session (self , session_id : str ) -> UsernameMappingSession :
655+ """Look up the given username mapping session
656+
657+ If it is not found, raises a SynapseError with an http code of 400
658+
659+ Args:
660+ session_id: session to look up
661+ Returns:
662+ active mapping session
663+ Raises:
664+ SynapseError if the session is not found/has expired
665+ """
666+ self ._expire_old_sessions ()
667+ session = self ._username_mapping_sessions .get (session_id )
668+ if session :
669+ return session
670+ logger .info ("Couldn't find session id %s" , session_id )
671+ raise SynapseError (400 , "unknown session" )
672+
650673 async def check_username_availability (
651674 self , localpart : str , session_id : str ,
652675 ) -> bool :
@@ -663,12 +686,7 @@ async def check_username_availability(
663686
664687 # make sure that there is a valid mapping session, to stop people dictionary-
665688 # scanning for accounts
666-
667- self ._expire_old_sessions ()
668- session = self ._username_mapping_sessions .get (session_id )
669- if not session :
670- logger .info ("Couldn't find session id %s" , session_id )
671- raise SynapseError (400 , "unknown session" )
689+ self .get_mapping_session (session_id )
672690
673691 logger .info (
674692 "[session %s] Checking for availability of username %s" ,
@@ -696,16 +714,33 @@ async def handle_submit_username_request(
696714 localpart: localpart requested by the user
697715 session_id: ID of the username mapping session, extracted from a cookie
698716 """
699- self ._expire_old_sessions ()
700- session = self ._username_mapping_sessions .get (session_id )
701- if not session :
702- logger .info ("Couldn't find session id %s" , session_id )
703- raise SynapseError (400 , "unknown session" )
717+ session = self .get_mapping_session (session_id )
718+
719+ # update the session with the user's choices
720+ session .chosen_localpart = localpart
721+
722+ # we're done; now we can register the user
723+ respond_with_redirect (request , b"/_synapse/client/sso_register" )
724+
725+ async def register_sso_user (self , request : Request , session_id : str ) -> None :
726+ """Called once we have all the info we need to register a new user.
704727
705- logger .info ("[session %s] Registering localpart %s" , session_id , localpart )
728+ Does so and serves an HTTP response
729+
730+ Args:
731+ request: HTTP request
732+ session_id: ID of the username mapping session, extracted from a cookie
733+ """
734+ session = self .get_mapping_session (session_id )
735+
736+ logger .info (
737+ "[session %s] Registering localpart %s" ,
738+ session_id ,
739+ session .chosen_localpart ,
740+ )
706741
707742 attributes = UserAttributes (
708- localpart = localpart ,
743+ localpart = session . chosen_localpart ,
709744 display_name = session .display_name ,
710745 emails = session .emails ,
711746 )
@@ -720,7 +755,12 @@ async def handle_submit_username_request(
720755 request .getClientIP (),
721756 )
722757
723- logger .info ("[session %s] Registered userid %s" , session_id , user_id )
758+ logger .info (
759+ "[session %s] Registered userid %s with attributes %s" ,
760+ session_id ,
761+ user_id ,
762+ attributes ,
763+ )
724764
725765 # delete the mapping session and the cookie
726766 del self ._username_mapping_sessions [session_id ]
@@ -751,3 +791,14 @@ def _expire_old_sessions(self):
751791 for session_id in to_expire :
752792 logger .info ("Expiring mapping session %s" , session_id )
753793 del self ._username_mapping_sessions [session_id ]
794+
795+
796+ def get_username_mapping_session_cookie_from_request (request : IRequest ) -> str :
797+ """Extract the session ID from the cookie
798+
799+ Raises a SynapseError if the cookie isn't found
800+ """
801+ session_id = request .getCookie (USERNAME_MAPPING_SESSION_COOKIE_NAME )
802+ if not session_id :
803+ raise SynapseError (code = 400 , msg = "missing session_id" )
804+ return session_id .decode ("ascii" , errors = "replace" )
0 commit comments