Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Abstract shared SSO code. (#8765)
Browse files Browse the repository at this point in the history
De-duplicates code between the SAML and OIDC implementations.
  • Loading branch information
clokep authored Nov 17, 2020
1 parent e487d9f commit ee38202
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 120 deletions.
1 change: 1 addition & 0 deletions changelog.d/8765.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Consolidate logic between the OpenID Connect and SAML code.
92 changes: 33 additions & 59 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from twisted.web.client import readBody

from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
Expand Down Expand Up @@ -83,17 +84,12 @@ def __str__(self):
return self.error


class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""


class OidcHandler:
class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""

def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
Expand All @@ -120,36 +116,13 @@ def __init__(self, hs: "HomeServer"):
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template

# identifier for the external_ids table
self._auth_provider_id = "oidc"

def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error. Those include
well-known OAuth2/OIDC error types like invalid_request or
access_denied.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
self._sso_handler = hs.get_sso_handler()

def _validate_metadata(self):
"""Verifies the provider metadata.
Expand Down Expand Up @@ -571,7 +544,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._render_error`` which displays an HTML page for the error.
``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
Expand Down Expand Up @@ -609,7 +582,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)

self._render_error(request, error, description)
self._sso_handler.render_error(request, error, description)
return

# otherwise, it is presumably a successful response. see:
Expand All @@ -619,7 +592,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._render_error(request, "missing_session", "No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return

# Remove the cookie. There is a good chance that if the callback failed
Expand All @@ -637,7 +612,9 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._render_error(request, "invalid_request", "State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return

state = request.args[b"state"][0].decode()
Expand All @@ -651,17 +628,19 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e))
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._render_error(request, "mismatching_session", str(e))
self._sso_handler.render_error(request, "mismatching_session", str(e))
return

# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._render_error(request, "invalid_request", "Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return

logger.debug("Exchanging code")
Expand All @@ -670,7 +649,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
self._render_error(request, e.error, e.error_description)
self._sso_handler.render_error(request, e.error, e.error_description)
return

logger.debug("Successfully obtained OAuth2 access token")
Expand All @@ -683,15 +662,15 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
self._render_error(request, "fetch_error", str(e))
self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._render_error(request, "invalid_token", str(e))
self._sso_handler.render_error(request, "invalid_token", str(e))
return

# Pull out the user-agent and IP from the request.
Expand All @@ -705,7 +684,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
self._sso_handler.render_error(request, "mapping_error", str(e))
return

# Mapping providers might not have get_extra_attributes: only call this
Expand Down Expand Up @@ -770,7 +749,7 @@ def _generate_oidc_session_token(
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self._clock.time_msec()
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))

Expand Down Expand Up @@ -845,7 +824,7 @@ def _verify_expiry(self, caveat: str) -> bool:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
now = self.clock.time_msec()
return now < expiry

async def _map_userinfo_to_user(
Expand Down Expand Up @@ -885,20 +864,14 @@ async def _map_userinfo_to_user(
# to be strings.
remote_user_id = str(remote_user_id)

logger.info(
"Looking for existing mapping for user %s:%s",
self._auth_provider_id,
remote_user_id,
)

registered_user_id = await self._datastore.get_user_by_external_id(
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
)
if previously_registered_user_id:
return previously_registered_user_id

if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id

# Otherwise, generate a new user.
try:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
Expand All @@ -917,8 +890,8 @@ async def _map_userinfo_to_user(

localpart = map_username_to_mxid_localpart(attributes["localpart"])

user_id = UserID(localpart, self._hostname).to_string()
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
user_id = UserID(localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if self._allow_existing_users:
if len(users) == 1:
Expand All @@ -942,7 +915,8 @@ async def _map_userinfo_to_user(
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(

await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
return registered_user_id
Expand Down
Loading

0 comments on commit ee38202

Please sign in to comment.