3434from twisted .web .client import readBody
3535
3636from synapse .config import ConfigError
37- from synapse .http .server import respond_with_html
37+ from synapse .handlers ._base import BaseHandler
38+ from synapse .handlers .sso import MappingException
3839from synapse .http .site import SynapseRequest
3940from synapse .logging .context import make_deferred_yieldable
4041from synapse .types import JsonDict , UserID , map_username_to_mxid_localpart
@@ -83,17 +84,12 @@ def __str__(self):
8384 return self .error
8485
8586
86- class MappingException (Exception ):
87- """Used to catch errors when mapping the UserInfo object
88- """
89-
90-
91- class OidcHandler :
87+ class OidcHandler (BaseHandler ):
9288 """Handles requests related to the OpenID Connect login flow.
9389 """
9490
9591 def __init__ (self , hs : "HomeServer" ):
96- self . hs = hs
92+ super (). __init__ ( hs )
9793 self ._callback_url = hs .config .oidc_callback_url # type: str
9894 self ._scopes = hs .config .oidc_scopes # type: List[str]
9995 self ._user_profile_method = hs .config .oidc_user_profile_method # type: str
@@ -120,36 +116,13 @@ def __init__(self, hs: "HomeServer"):
120116 self ._http_client = hs .get_proxied_http_client ()
121117 self ._auth_handler = hs .get_auth_handler ()
122118 self ._registration_handler = hs .get_registration_handler ()
123- self ._datastore = hs .get_datastore ()
124- self ._clock = hs .get_clock ()
125- self ._hostname = hs .hostname # type: str
126119 self ._server_name = hs .config .server_name # type: str
127120 self ._macaroon_secret_key = hs .config .macaroon_secret_key
128- self ._error_template = hs .config .sso_error_template
129121
130122 # identifier for the external_ids table
131123 self ._auth_provider_id = "oidc"
132124
133- def _render_error (
134- self , request , error : str , error_description : Optional [str ] = None
135- ) -> None :
136- """Render the error template and respond to the request with it.
137-
138- This is used to show errors to the user. The template of this page can
139- be found under `synapse/res/templates/sso_error.html`.
140-
141- Args:
142- request: The incoming request from the browser.
143- We'll respond with an HTML page describing the error.
144- error: A technical identifier for this error. Those include
145- well-known OAuth2/OIDC error types like invalid_request or
146- access_denied.
147- error_description: A human-readable description of the error.
148- """
149- html = self ._error_template .render (
150- error = error , error_description = error_description
151- )
152- respond_with_html (request , 400 , html )
125+ self ._sso_handler = hs .get_sso_handler ()
153126
154127 def _validate_metadata (self ):
155128 """Verifies the provider metadata.
@@ -571,7 +544,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
571544
572545 Since we might want to display OIDC-related errors in a user-friendly
573546 way, we don't raise SynapseError from here. Instead, we call
574- ``self._render_error `` which displays an HTML page for the error.
547+ ``self._sso_handler.render_error `` which displays an HTML page for the error.
575548
576549 Most of the OpenID Connect logic happens here:
577550
@@ -609,7 +582,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
609582 if error != "access_denied" :
610583 logger .error ("Error from the OIDC provider: %s %s" , error , description )
611584
612- self ._render_error (request , error , description )
585+ self ._sso_handler . render_error (request , error , description )
613586 return
614587
615588 # otherwise, it is presumably a successful response. see:
@@ -619,7 +592,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
619592 session = request .getCookie (SESSION_COOKIE_NAME ) # type: Optional[bytes]
620593 if session is None :
621594 logger .info ("No session cookie found" )
622- self ._render_error (request , "missing_session" , "No session cookie found" )
595+ self ._sso_handler .render_error (
596+ request , "missing_session" , "No session cookie found"
597+ )
623598 return
624599
625600 # Remove the cookie. There is a good chance that if the callback failed
@@ -637,7 +612,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
637612 # Check for the state query parameter
638613 if b"state" not in request .args :
639614 logger .info ("State parameter is missing" )
640- self ._render_error (request , "invalid_request" , "State parameter is missing" )
615+ self ._sso_handler .render_error (
616+ request , "invalid_request" , "State parameter is missing"
617+ )
641618 return
642619
643620 state = request .args [b"state" ][0 ].decode ()
@@ -651,17 +628,19 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
651628 ) = self ._verify_oidc_session_token (session , state )
652629 except MacaroonDeserializationException as e :
653630 logger .exception ("Invalid session" )
654- self ._render_error (request , "invalid_session" , str (e ))
631+ self ._sso_handler . render_error (request , "invalid_session" , str (e ))
655632 return
656633 except MacaroonInvalidSignatureException as e :
657634 logger .exception ("Could not verify session" )
658- self ._render_error (request , "mismatching_session" , str (e ))
635+ self ._sso_handler . render_error (request , "mismatching_session" , str (e ))
659636 return
660637
661638 # Exchange the code with the provider
662639 if b"code" not in request .args :
663640 logger .info ("Code parameter is missing" )
664- self ._render_error (request , "invalid_request" , "Code parameter is missing" )
641+ self ._sso_handler .render_error (
642+ request , "invalid_request" , "Code parameter is missing"
643+ )
665644 return
666645
667646 logger .debug ("Exchanging code" )
@@ -670,7 +649,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
670649 token = await self ._exchange_code (code )
671650 except OidcError as e :
672651 logger .exception ("Could not exchange code" )
673- self ._render_error (request , e .error , e .error_description )
652+ self ._sso_handler . render_error (request , e .error , e .error_description )
674653 return
675654
676655 logger .debug ("Successfully obtained OAuth2 access token" )
@@ -683,15 +662,15 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
683662 userinfo = await self ._fetch_userinfo (token )
684663 except Exception as e :
685664 logger .exception ("Could not fetch userinfo" )
686- self ._render_error (request , "fetch_error" , str (e ))
665+ self ._sso_handler . render_error (request , "fetch_error" , str (e ))
687666 return
688667 else :
689668 logger .debug ("Extracting userinfo from id_token" )
690669 try :
691670 userinfo = await self ._parse_id_token (token , nonce = nonce )
692671 except Exception as e :
693672 logger .exception ("Invalid id_token" )
694- self ._render_error (request , "invalid_token" , str (e ))
673+ self ._sso_handler . render_error (request , "invalid_token" , str (e ))
695674 return
696675
697676 # Pull out the user-agent and IP from the request.
@@ -705,7 +684,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
705684 )
706685 except MappingException as e :
707686 logger .exception ("Could not map user" )
708- self ._render_error (request , "mapping_error" , str (e ))
687+ self ._sso_handler . render_error (request , "mapping_error" , str (e ))
709688 return
710689
711690 # Mapping providers might not have get_extra_attributes: only call this
@@ -770,7 +749,7 @@ def _generate_oidc_session_token(
770749 macaroon .add_first_party_caveat (
771750 "ui_auth_session_id = %s" % (ui_auth_session_id ,)
772751 )
773- now = self ._clock .time_msec ()
752+ now = self .clock .time_msec ()
774753 expiry = now + duration_in_ms
775754 macaroon .add_first_party_caveat ("time < %d" % (expiry ,))
776755
@@ -845,7 +824,7 @@ def _verify_expiry(self, caveat: str) -> bool:
845824 if not caveat .startswith (prefix ):
846825 return False
847826 expiry = int (caveat [len (prefix ) :])
848- now = self ._clock .time_msec ()
827+ now = self .clock .time_msec ()
849828 return now < expiry
850829
851830 async def _map_userinfo_to_user (
@@ -885,20 +864,14 @@ async def _map_userinfo_to_user(
885864 # to be strings.
886865 remote_user_id = str (remote_user_id )
887866
888- logger .info (
889- "Looking for existing mapping for user %s:%s" ,
890- self ._auth_provider_id ,
891- remote_user_id ,
892- )
893-
894- registered_user_id = await self ._datastore .get_user_by_external_id (
867+ # first of all, check if we already have a mapping for this user
868+ previously_registered_user_id = await self ._sso_handler .get_sso_user_by_remote_user_id (
895869 self ._auth_provider_id , remote_user_id ,
896870 )
871+ if previously_registered_user_id :
872+ return previously_registered_user_id
897873
898- if registered_user_id is not None :
899- logger .info ("Found existing mapping %s" , registered_user_id )
900- return registered_user_id
901-
874+ # Otherwise, generate a new user.
902875 try :
903876 attributes = await self ._user_mapping_provider .map_user_attributes (
904877 userinfo , token
@@ -917,8 +890,8 @@ async def _map_userinfo_to_user(
917890
918891 localpart = map_username_to_mxid_localpart (attributes ["localpart" ])
919892
920- user_id = UserID (localpart , self ._hostname ).to_string ()
921- users = await self ._datastore .get_users_by_id_case_insensitive (user_id )
893+ user_id = UserID (localpart , self .server_name ).to_string ()
894+ users = await self .store .get_users_by_id_case_insensitive (user_id )
922895 if users :
923896 if self ._allow_existing_users :
924897 if len (users ) == 1 :
@@ -942,7 +915,8 @@ async def _map_userinfo_to_user(
942915 default_display_name = attributes ["display_name" ],
943916 user_agent_ips = (user_agent , ip_address ),
944917 )
945- await self ._datastore .record_user_external_id (
918+
919+ await self .store .record_user_external_id (
946920 self ._auth_provider_id , remote_user_id , registered_user_id ,
947921 )
948922 return registered_user_id
0 commit comments