Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Push login completion down into SsoHandler (#8941)
Browse files Browse the repository at this point in the history
This is another part of my work towards fixing #8876. It moves some of the logic currently in the SAML and OIDC handlers - in particular the call to `AuthHandler.complete_sso_login` down into the `SsoHandler`.
  • Loading branch information
richvdh authored Dec 16, 2020
1 parent 44b7d4c commit e1b8e37
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 86 deletions.
1 change: 1 addition & 0 deletions changelog.d/8941.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.
62 changes: 27 additions & 35 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def __init__(self, hs: "HomeServer"):
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool

self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key

Expand Down Expand Up @@ -689,33 +687,14 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:

# otherwise, it's a login

# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)

# Call the mapper to register/login the user
try:
user_id = await self._map_userinfo_to_user(
userinfo, token, user_agent, ip_address
await self._complete_oidc_login(
userinfo, token, request, client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
return

# Mapping providers might not have get_extra_attributes: only call this
# method if it exists.
extra_attributes = None
get_extra_attributes = getattr(
self._user_mapping_provider, "get_extra_attributes", None
)
if get_extra_attributes:
extra_attributes = await get_extra_attributes(userinfo, token)

# and finally complete the login
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url, extra_attributes
)

def _generate_oidc_session_token(
self,
Expand Down Expand Up @@ -838,10 +817,14 @@ def _verify_expiry(self, caveat: str) -> bool:
now = self.clock.time_msec()
return now < expiry

async def _map_userinfo_to_user(
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
) -> str:
"""Maps a UserInfo object to a mxid.
async def _complete_oidc_login(
self,
userinfo: UserInfo,
token: Token,
request: SynapseRequest,
client_redirect_url: str,
) -> None:
"""Given a UserInfo response, complete the login flow
UserInfo should have a claim that uniquely identifies users. This claim
is usually `sub`, but can be configured with `oidc_config.subject_claim`.
Expand All @@ -853,17 +836,16 @@ async def _map_userinfo_to_user(
If a user already exists with the mxid we've mapped and allow_existing_users
is disabled, raise an exception.
Otherwise, render a redirect back to the client_redirect_url with a loginToken.
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
Raises:
MappingException: if there was an error while mapping some properties
Returns:
The mxid of the user
"""
try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
Expand Down Expand Up @@ -931,13 +913,23 @@ async def grandfather_existing_users() -> Optional[str]:

return None

return await self._sso_handler.get_mxid_from_sso(
# Mapping providers might not have get_extra_attributes: only call this
# method if it exists.
extra_attributes = None
get_extra_attributes = getattr(
self._user_mapping_provider, "get_extra_attributes", None
)
if get_extra_attributes:
extra_attributes = await get_extra_attributes(userinfo, token)

await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
request,
client_redirect_url,
oidc_response_to_user_attributes,
grandfather_existing_users,
extra_attributes,
)

def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
Expand Down
37 changes: 12 additions & 25 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()

self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
Expand Down Expand Up @@ -229,40 +227,29 @@ async def _handle_authn_response(
)
return

# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)

# Call the mapper to register/login the user
try:
user_id = await self._map_saml_response_to_user(
saml2_auth, relay_state, user_agent, ip_address
)
await self._complete_saml_login(saml2_auth, request, relay_state)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
return

await self._auth_handler.complete_sso_login(user_id, request, relay_state)

async def _map_saml_response_to_user(
async def _complete_saml_login(
self,
saml2_auth: saml2.response.AuthnResponse,
request: SynapseRequest,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> str:
) -> None:
"""
Given a SAML response, retrieve the user ID for it and possibly register the user.
Given a SAML response, complete the login flow
Retrieves the remote user ID, registers the user if necessary, and serves
a redirect back to the client with a login-token.
Args:
saml2_auth: The parsed SAML2 response.
request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
The user ID associated with this response.
Raises:
MappingException if there was a problem mapping the response to a user.
Expand Down Expand Up @@ -318,11 +305,11 @@ async def grandfather_existing_users() -> Optional[str]:

return None

return await self._sso_handler.get_mxid_from_sso(
await self._sso_handler.complete_sso_login_request(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
request,
client_redirect_url,
saml_response_to_remapped_user_attributes,
grandfather_existing_users,
)
Expand Down
58 changes: 36 additions & 22 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

from synapse.api.errors import RedirectException
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,15 +120,16 @@ async def get_sso_user_by_remote_user_id(
# No match.
return None

async def get_mxid_from_sso(
async def complete_sso_login_request(
self,
auth_provider_id: str,
remote_user_id: str,
user_agent: str,
ip_address: str,
request: SynapseRequest,
client_redirect_url: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
) -> str:
extra_login_attributes: Optional[JsonDict] = None,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
Expand All @@ -146,12 +148,18 @@ async def get_mxid_from_sso(
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
Finally, we generate a redirect to the supplied redirect uri, with a login token
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
remote_user_id: The unique identifier from the SSO provider.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed.
Expand All @@ -163,12 +171,13 @@ async def get_mxid_from_sso(
to the user.
RedirectException to redirect to an additional page (e.g.
to prompt the user for more information).
grandfather_existing_users: A callable which can return an previously
existing matrix ID. The SSO ID is then linked to the returned
matrix ID.
Returns:
The user ID associated with the SSO response.
extra_login_attributes: An optional dictionary of extra
attributes to be provided to the client in the login response.
Raises:
MappingException if there was a problem mapping the response to a user.
Expand All @@ -181,28 +190,33 @@ async def get_mxid_from_sso(
# interstitial pages.
with await self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
)
if previously_registered_user_id:
return previously_registered_user_id

# Check for grandfathering of users.
if grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
if not user_id and grandfather_existing_users:
user_id = await grandfather_existing_users()
if user_id:
# Future logins should also match this user ID.
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
auth_provider_id, remote_user_id, user_id
)
return previously_registered_user_id

# Otherwise, generate a new user.
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
user_id = await self._register_mapped_user(
attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
)
return user_id
if not user_id:
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
user_id = await self._register_mapped_user(
attributes,
auth_provider_id,
remote_user_id,
request.get_user_agent(""),
request.getClientIP(),
)

await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url, extra_login_attributes
)

async def _call_attribute_mapper(
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
Expand Down
8 changes: 4 additions & 4 deletions tests/handlers/test_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_map_saml_response_to_user(self):

# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri"
"@test_user:test", request, "redirect_uri", None
)

@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
Expand All @@ -157,7 +157,7 @@ def test_map_saml_response_to_existing_user(self):

# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, ""
"@test_user:test", request, "", None
)

# Subsequent calls should map to the same mxid.
Expand All @@ -166,7 +166,7 @@ def test_map_saml_response_to_existing_user(self):
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, ""
"@test_user:test", request, "", None
)

def test_map_saml_response_to_invalid_localpart(self):
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_map_saml_response_to_user_retries(self):

# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", request, ""
"@test_user1:test", request, "", None
)
auth_handler.complete_sso_login.reset_mock()

Expand Down

0 comments on commit e1b8e37

Please sign in to comment.