Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit ee38202

Browse files
authored
Abstract shared SSO code. (#8765)
De-duplicates code between the SAML and OIDC implementations.
1 parent e487d9f commit ee38202

File tree

6 files changed

+159
-120
lines changed

6 files changed

+159
-120
lines changed

changelog.d/8765.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Consolidate logic between the OpenID Connect and SAML code.

synapse/handlers/oidc_handler.py

Lines changed: 33 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
from twisted.web.client import readBody
3535

3636
from synapse.config import ConfigError
37-
from synapse.http.server import respond_with_html
37+
from synapse.handlers._base import BaseHandler
38+
from synapse.handlers.sso import MappingException
3839
from synapse.http.site import SynapseRequest
3940
from synapse.logging.context import make_deferred_yieldable
4041
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -83,17 +84,12 @@ def __str__(self):
8384
return self.error
8485

8586

86-
class MappingException(Exception):
87-
"""Used to catch errors when mapping the UserInfo object
88-
"""
89-
90-
91-
class OidcHandler:
87+
class OidcHandler(BaseHandler):
9288
"""Handles requests related to the OpenID Connect login flow.
9389
"""
9490

9591
def __init__(self, hs: "HomeServer"):
96-
self.hs = hs
92+
super().__init__(hs)
9793
self._callback_url = hs.config.oidc_callback_url # type: str
9894
self._scopes = hs.config.oidc_scopes # type: List[str]
9995
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@@ -120,36 +116,13 @@ def __init__(self, hs: "HomeServer"):
120116
self._http_client = hs.get_proxied_http_client()
121117
self._auth_handler = hs.get_auth_handler()
122118
self._registration_handler = hs.get_registration_handler()
123-
self._datastore = hs.get_datastore()
124-
self._clock = hs.get_clock()
125-
self._hostname = hs.hostname # type: str
126119
self._server_name = hs.config.server_name # type: str
127120
self._macaroon_secret_key = hs.config.macaroon_secret_key
128-
self._error_template = hs.config.sso_error_template
129121

130122
# identifier for the external_ids table
131123
self._auth_provider_id = "oidc"
132124

133-
def _render_error(
134-
self, request, error: str, error_description: Optional[str] = None
135-
) -> None:
136-
"""Render the error template and respond to the request with it.
137-
138-
This is used to show errors to the user. The template of this page can
139-
be found under `synapse/res/templates/sso_error.html`.
140-
141-
Args:
142-
request: The incoming request from the browser.
143-
We'll respond with an HTML page describing the error.
144-
error: A technical identifier for this error. Those include
145-
well-known OAuth2/OIDC error types like invalid_request or
146-
access_denied.
147-
error_description: A human-readable description of the error.
148-
"""
149-
html = self._error_template.render(
150-
error=error, error_description=error_description
151-
)
152-
respond_with_html(request, 400, html)
125+
self._sso_handler = hs.get_sso_handler()
153126

154127
def _validate_metadata(self):
155128
"""Verifies the provider metadata.
@@ -571,7 +544,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
571544
572545
Since we might want to display OIDC-related errors in a user-friendly
573546
way, we don't raise SynapseError from here. Instead, we call
574-
``self._render_error`` which displays an HTML page for the error.
547+
``self._sso_handler.render_error`` which displays an HTML page for the error.
575548
576549
Most of the OpenID Connect logic happens here:
577550
@@ -609,7 +582,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
609582
if error != "access_denied":
610583
logger.error("Error from the OIDC provider: %s %s", error, description)
611584

612-
self._render_error(request, error, description)
585+
self._sso_handler.render_error(request, error, description)
613586
return
614587

615588
# otherwise, it is presumably a successful response. see:
@@ -619,7 +592,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
619592
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
620593
if session is None:
621594
logger.info("No session cookie found")
622-
self._render_error(request, "missing_session", "No session cookie found")
595+
self._sso_handler.render_error(
596+
request, "missing_session", "No session cookie found"
597+
)
623598
return
624599

625600
# Remove the cookie. There is a good chance that if the callback failed
@@ -637,7 +612,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
637612
# Check for the state query parameter
638613
if b"state" not in request.args:
639614
logger.info("State parameter is missing")
640-
self._render_error(request, "invalid_request", "State parameter is missing")
615+
self._sso_handler.render_error(
616+
request, "invalid_request", "State parameter is missing"
617+
)
641618
return
642619

643620
state = request.args[b"state"][0].decode()
@@ -651,17 +628,19 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
651628
) = self._verify_oidc_session_token(session, state)
652629
except MacaroonDeserializationException as e:
653630
logger.exception("Invalid session")
654-
self._render_error(request, "invalid_session", str(e))
631+
self._sso_handler.render_error(request, "invalid_session", str(e))
655632
return
656633
except MacaroonInvalidSignatureException as e:
657634
logger.exception("Could not verify session")
658-
self._render_error(request, "mismatching_session", str(e))
635+
self._sso_handler.render_error(request, "mismatching_session", str(e))
659636
return
660637

661638
# Exchange the code with the provider
662639
if b"code" not in request.args:
663640
logger.info("Code parameter is missing")
664-
self._render_error(request, "invalid_request", "Code parameter is missing")
641+
self._sso_handler.render_error(
642+
request, "invalid_request", "Code parameter is missing"
643+
)
665644
return
666645

667646
logger.debug("Exchanging code")
@@ -670,7 +649,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
670649
token = await self._exchange_code(code)
671650
except OidcError as e:
672651
logger.exception("Could not exchange code")
673-
self._render_error(request, e.error, e.error_description)
652+
self._sso_handler.render_error(request, e.error, e.error_description)
674653
return
675654

676655
logger.debug("Successfully obtained OAuth2 access token")
@@ -683,15 +662,15 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
683662
userinfo = await self._fetch_userinfo(token)
684663
except Exception as e:
685664
logger.exception("Could not fetch userinfo")
686-
self._render_error(request, "fetch_error", str(e))
665+
self._sso_handler.render_error(request, "fetch_error", str(e))
687666
return
688667
else:
689668
logger.debug("Extracting userinfo from id_token")
690669
try:
691670
userinfo = await self._parse_id_token(token, nonce=nonce)
692671
except Exception as e:
693672
logger.exception("Invalid id_token")
694-
self._render_error(request, "invalid_token", str(e))
673+
self._sso_handler.render_error(request, "invalid_token", str(e))
695674
return
696675

697676
# Pull out the user-agent and IP from the request.
@@ -705,7 +684,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
705684
)
706685
except MappingException as e:
707686
logger.exception("Could not map user")
708-
self._render_error(request, "mapping_error", str(e))
687+
self._sso_handler.render_error(request, "mapping_error", str(e))
709688
return
710689

711690
# Mapping providers might not have get_extra_attributes: only call this
@@ -770,7 +749,7 @@ def _generate_oidc_session_token(
770749
macaroon.add_first_party_caveat(
771750
"ui_auth_session_id = %s" % (ui_auth_session_id,)
772751
)
773-
now = self._clock.time_msec()
752+
now = self.clock.time_msec()
774753
expiry = now + duration_in_ms
775754
macaroon.add_first_party_caveat("time < %d" % (expiry,))
776755

@@ -845,7 +824,7 @@ def _verify_expiry(self, caveat: str) -> bool:
845824
if not caveat.startswith(prefix):
846825
return False
847826
expiry = int(caveat[len(prefix) :])
848-
now = self._clock.time_msec()
827+
now = self.clock.time_msec()
849828
return now < expiry
850829

851830
async def _map_userinfo_to_user(
@@ -885,20 +864,14 @@ async def _map_userinfo_to_user(
885864
# to be strings.
886865
remote_user_id = str(remote_user_id)
887866

888-
logger.info(
889-
"Looking for existing mapping for user %s:%s",
890-
self._auth_provider_id,
891-
remote_user_id,
892-
)
893-
894-
registered_user_id = await self._datastore.get_user_by_external_id(
867+
# first of all, check if we already have a mapping for this user
868+
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
895869
self._auth_provider_id, remote_user_id,
896870
)
871+
if previously_registered_user_id:
872+
return previously_registered_user_id
897873

898-
if registered_user_id is not None:
899-
logger.info("Found existing mapping %s", registered_user_id)
900-
return registered_user_id
901-
874+
# Otherwise, generate a new user.
902875
try:
903876
attributes = await self._user_mapping_provider.map_user_attributes(
904877
userinfo, token
@@ -917,8 +890,8 @@ async def _map_userinfo_to_user(
917890

918891
localpart = map_username_to_mxid_localpart(attributes["localpart"])
919892

920-
user_id = UserID(localpart, self._hostname).to_string()
921-
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
893+
user_id = UserID(localpart, self.server_name).to_string()
894+
users = await self.store.get_users_by_id_case_insensitive(user_id)
922895
if users:
923896
if self._allow_existing_users:
924897
if len(users) == 1:
@@ -942,7 +915,8 @@ async def _map_userinfo_to_user(
942915
default_display_name=attributes["display_name"],
943916
user_agent_ips=(user_agent, ip_address),
944917
)
945-
await self._datastore.record_user_external_id(
918+
919+
await self.store.record_user_external_id(
946920
self._auth_provider_id, remote_user_id, registered_user_id,
947921
)
948922
return registered_user_id

0 commit comments

Comments
 (0)