Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Preparatory refactoring of the SamlHandlerTestCase #8938

Merged
merged 3 commits into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8938.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.
23 changes: 23 additions & 0 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,29 @@ async def handle_saml_response(self, request: SynapseRequest) -> None:
return

logger.debug("SAML2 response: %s", saml2_auth.origxml)

await self._handle_authn_response(request, saml2_auth, relay_state)

async def _handle_authn_response(
self,
request: SynapseRequest,
saml2_auth: saml2.response.AuthnResponse,
relay_state: str,
) -> None:
"""Handle an AuthnResponse, having parsed it from the request params

Assumes that the signature on the response object has been checked. Maps
the user onto an MXID, registering them if necessary, and returns a response
to the browser.

Args:
request: the incoming request from the browser. We'll respond to it with an
HTML page or a redirect
saml2_auth: the parsed AuthnResponse object
relay_state: the RelayState query param, which encodes the URI to rediret
back to
"""

for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
Expand Down
12 changes: 1 addition & 11 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from synapse.handlers.sso import MappingException
from synapse.types import UserID

from tests.test_utils import FakeResponse
from tests.test_utils import FakeResponse, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config

# These are a few constants that are used as config parameters in the tests.
Expand Down Expand Up @@ -82,16 +82,6 @@ async def map_user_attributes(self, userinfo, token, failures):
}


def simple_async_mock(return_value=None, raises=None) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
if raises:
raise raises
return return_value

return Mock(side_effect=cb)


async def get_json(url):
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
Expand Down
132 changes: 89 additions & 43 deletions tests/handlers/test_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from mock import Mock

import attr

from synapse.api.errors import RedirectException
from synapse.handlers.sso import MappingException

from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config

# Check if we have the dependencies to run the tests.
Expand Down Expand Up @@ -44,6 +48,8 @@
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
assertions = attr.ib(type=list, factory=list)
in_response_to = attr.ib(type=Optional[str], default=None)


class TestMappingProvider:
Expand Down Expand Up @@ -111,15 +117,22 @@ def make_homeserver(self, reactor, clock):

def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""

# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()

# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
# The redirect_url doesn't matter with the default user mapping provider.
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "redirect_uri")
)

# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri"
)
self.assertEqual(mxid, "@test_user:test")

@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
Expand All @@ -129,53 +142,81 @@ def test_map_saml_response_to_existing_user(self):
store.register_user(user_id="@test_user:test", password_hash=None)
)

# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()

# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "")
)

# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, ""
)
self.assertEqual(mxid, "@test_user:test")

# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
auth_handler.complete_sso_login.reset_mock()
self.get_success(
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, ""
)
self.assertEqual(mxid, "@test_user:test")

def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""

# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()

# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None)

saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
redirect_url = ""
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, ""),
)
sso_handler.render_error.assert_called_once_with(
request, "mapping_error", "localpart is invalid: föö"
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
auth_handler.complete_sso_login.assert_not_called()

def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""

# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None)

# register a user to occupy the first-choice MXID
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)

# send the fake SAML response
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
request = _mock_request()
self.get_success(
self.handler._handle_authn_response(request, saml_response, ""),
)

# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", request, ""
)
auth_handler.complete_sso_login.reset_mock()

# Register all of the potential mxids for a particular SAML username.
self.get_success(
Expand All @@ -188,15 +229,15 @@ def test_map_saml_response_to_user_retries(self):

# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
self.get_success(
self.handler._handle_authn_response(request, saml_response, ""),
)
self.assertEqual(
str(e.value), "Unable to generate a Matrix ID from the SSO response"
sso_handler.render_error.assert_called_once_with(
request,
"mapping_error",
"Unable to generate a Matrix ID from the SSO response",
)
auth_handler.complete_sso_login.assert_not_called()

@override_config(
{
Expand All @@ -208,12 +249,17 @@ def test_map_saml_response_to_user_retries(self):
}
)
def test_map_saml_response_redirect(self):
"""Test a mapping provider that raises a RedirectException"""

saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
redirect_url = ""
request = _mock_request()
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
self.handler._handle_authn_response(request, saml_response, ""),
RedirectException,
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")


def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
return Mock(spec=["getClientIP", "get_user_agent"])
12 changes: 12 additions & 0 deletions tests/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar

from mock import Mock

import attr

from twisted.python.failure import Failure
Expand Down Expand Up @@ -87,6 +89,16 @@ def cleanup():
return cleanup


def simple_async_mock(return_value=None, raises=None) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
if raises:
raise raises
return return_value

return Mock(side_effect=cb)


@attr.s
class FakeResponse:
"""A fake twisted.web.IResponse object
Expand Down