Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add some tests for the IDP picker flow
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Jan 7, 2021
1 parent bbd0444 commit 8a910f9
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 5 deletions.
4 changes: 2 additions & 2 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,9 @@ def __init__(self, hs: "HomeServer"):
# register themselves with the main SSOHandler.
if hs.config.cas_enabled:
hs.get_cas_handler()
elif hs.config.saml2_enabled:
if hs.config.saml2_enabled:
hs.get_saml_handler()
elif hs.config.oidc_enabled:
if hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()

Expand Down
191 changes: 189 additions & 2 deletions tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,26 @@

import time
import urllib.parse
from typing import Any, Dict, Union
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from mock import Mock

import pymacaroons

from twisted.web.resource import Resource

import synapse.rest.admin
from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from synapse.rest.synapse.client.pick_idp import PickIdpResource

from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.unittest import override_config, skip_unless

try:
Expand Down Expand Up @@ -350,6 +359,184 @@ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
self.assertEquals(channel.result["code"], b"200", channel.result)


@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
"""Tests for homeservers with multiple SSO providers enabled"""

servlets = [
login.register_servlets,
]

def default_config(self) -> Dict[str, Any]:
config = super().default_config()

config["public_baseurl"] = BASE_URL

config["cas_config"] = {
"enabled": True,
"server_url": CAS_SERVER,
"service_url": "https://matrix.goodserver.com:8448",
}

config["saml2_config"] = {
"sp_config": {
"metadata": {"inline": [TEST_SAML_METADATA]},
# use the XMLSecurity backend to avoid relying on xmlsec1
"crypto_backend": "XMLSecurity",
},
}

config["oidc_config"] = TEST_OIDC_CONFIG

return config

def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
return d

def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
client_redirect_url = "https://x?<abc>"

# first hit the redirect url, which should redirect to our idp picker
channel = self.make_request(
"GET",
"/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
)
self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0]

# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
self.assertEqual(channel.code, 200, channel.result)

# parse the form to check it has fields assumed elsewhere in this class
class FormPageParser(HTMLParser):
def __init__(self):
super().__init__()

# the values of the hidden inputs: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]

# the values of the radio buttons
self.radios = [] # type: List[Optional[str]]

def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "input":
if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
self.radios.append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
input_name = attr_dict["name"]
assert input_name
self.hiddens[input_name] = attr_dict["value"]

def error(_, message):
self.fail(message)

p = FormPageParser()
p.feed(channel.result["body"].decode("utf-8"))
p.close()

self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])

self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)

def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
client_redirect_url = "https://x?<abc>"

channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
cas_uri = channel.headers.getRawHeaders("Location")[0]
cas_uri_path, cas_uri_query = cas_uri.split("?", 1)

# it should redirect us to the login page of the cas server
self.assertEqual(cas_uri_path, CAS_SERVER + "/login")

# check that the redirectUrl is correctly encoded in the service param - ie, the
# place that CAS will redirect to
cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
service_uri = cas_uri_params["service"][0]
_, service_uri_query = service_uri.split("?", 1)
service_uri_params = urllib.parse.parse_qs(service_uri_query)
self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)

def test_multi_sso_redirect_to_saml(self):
"""If SAML is chosen, should redirect to the SAML server"""
client_redirect_url = "https://x?<abc>"

channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
saml_uri = channel.headers.getRawHeaders("Location")[0]
saml_uri_path, saml_uri_query = saml_uri.split("?", 1)

# it should redirect us to the login page of the SAML server
self.assertEqual(saml_uri_path, SAML_SERVER)

# the RelayState is used to carry the client redirect url
saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
relay_state_param = saml_uri_params["RelayState"][0]
self.assertEqual(relay_state_param, client_redirect_url)

def test_multi_sso_redirect_to_oidc(self):
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
client_redirect_url = "https://x?<abc>"

channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)

# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)

# ... and should have set a cookie including the redirect url
cookies = dict(
h.split(";")[0].split("=", maxsplit=1)
for h in channel.headers.getRawHeaders("Set-Cookie")
)

oidc_session_cookie = cookies["oidc_session"]
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
self.assertEqual(
self._get_value_from_macaroon(macaroon, "client_redirect_url"),
client_redirect_url,
)

def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
"GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, 400, channel.result)

@staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))


class CASTestCase(unittest.HomeserverTestCase):

servlets = [
Expand All @@ -363,7 +550,7 @@ def make_homeserver(self, reactor, clock):
config = self.default_config()
config["cas_config"] = {
"enabled": True,
"server_url": "https://fake.test",
"server_url": CAS_SERVER,
"service_url": "https://matrix.goodserver.com:8448",
}

Expand Down
3 changes: 2 additions & 1 deletion tests/rest/client/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,14 +444,15 @@ async def mock_req(method: str, uri: str, data=None, headers=None):


# an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
"issuer": "https://issuer.test",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": "https://z",
"authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
"token_endpoint": "https://issuer.test/token",
"userinfo_endpoint": "https://issuer.test/userinfo",
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
Expand Down

0 comments on commit 8a910f9

Please sign in to comment.