Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor the CAS handler #8958

Merged
merged 7 commits into from
Dec 18, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8958.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Properly store the mapping of external ID to Matrix ID for CAS users.
211 changes: 147 additions & 64 deletions synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import urllib.parse
from typing import TYPE_CHECKING, Dict, Optional
from xml.etree import ElementTree as ET

import attr

from twisted.web.client import PartialDownloadError

from synapse.api.errors import Codes, LoginError
from synapse.api.errors import HttpResponseException
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart

Expand All @@ -29,6 +31,26 @@
logger = logging.getLogger(__name__)


class CasError(Exception):
"""Used to catch errors when validating the CAS ticket.
"""

def __init__(self, error, error_description=None):
self.error = error
self.error_description = error_description

def __str__(self):
if self.error_description:
return "{}: {}".format(self.error, self.error_description)
return self.error


@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, Optional[str]])


class CasHandler:
"""
Utility class for to handle the response from a CAS SSO service.
Expand All @@ -50,6 +72,8 @@ def __init__(self, hs: "HomeServer"):

self._http_client = hs.get_proxied_http_client()

self._sso_handler = hs.get_sso_handler()

def _build_service_param(self, args: Dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
Expand All @@ -69,14 +93,20 @@ def _build_service_param(self, args: Dict[str, str]) -> str:

async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
) -> Tuple[str, Optional[str]]:
) -> CasResponse:
"""
Validate a CAS ticket with the server, parse the response, and return the user and display name.
Validate a CAS ticket with the server, and return the parsed the response.

Args:
ticket: The CAS ticket from the client.
service_args: Additional arguments to include in the service URL.
Should be the same as those passed to `get_redirect_url`.

Raises:
CasError: If there's an error parsing the CAS response.

Returns:
The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
args = {
Expand All @@ -89,66 +119,65 @@ async def _validate_ticket(
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
except HttpResponseException as e:
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code),
)
raise CasError("server_error", description)
clokep marked this conversation as resolved.
Show resolved Hide resolved

user, attributes = self._parse_cas_response(body)
displayname = attributes.pop(self._cas_displayname_attribute, None)

for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

return user, displayname
return self._parse_cas_response(body)

def _parse_cas_response(
self, cas_response_body: bytes
) -> Tuple[str, Dict[str, Optional[str]]]:
def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
"""
Retrieve the user and other parameters from the CAS response.

Args:
cas_response_body: The response from the CAS query.

Raises:
CasError: If there's an error parsing the CAS response.

Returns:
A tuple of the user and a mapping of other attributes.
The parsed CAS response.
"""

# Ensure the response is valid.
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise CasError(
"missing_service_response",
"root of CAS response is not serviceResponse",
)

success = root[0].tag.endswith("authenticationSuccess")
if not success:
raise CasError("unsucessful_response", "Unsuccessful CAS response")

# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {}
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text

# Ensure a user was found.
if user is None:
raise CasError("no_user", "CAS response does not contain user")

return CasResponse(user, attributes)

def get_redirect_url(self, service_args: Dict[str, str]) -> str:
"""
Expand Down Expand Up @@ -201,15 +230,72 @@ async def handle_ticket(
args["redirectUrl"] = client_redirect_url
if session:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)

try:
cas_response = await self._validate_ticket(ticket, args)
except CasError as e:
logger.exception("Could not validate ticket")
self._sso_handler.render_error(request, e.error, e.error_description)
return

await self._handle_cas_response(
request, cas_response, client_redirect_url, session
)

async def _handle_cas_response(
self,
request: SynapseRequest,
cas_response: CasResponse,
client_redirect_url: Optional[str],
session: Optional[str],
) -> None:
"""Handle a CAS response to a ticket request.

Assumes that the response has been validated. Maps the user onto an MXID,
registering them if necessary, and returns a response to the browser.

Args:
request: the incoming request from the browser. We'll respond to it with an
HTML page or a redirect

cas_response: The parsed CAS response.

client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
This should be the same as the redirectUrl from the original `/login/sso/redirect` request.

session: The session parameter from the `/cas/ticket` HTTP request, if given.
This should be the UI Auth session id.
"""

# Ensure that the attributes of the logged in user meet the required
# attributes.
for required_attribute, required_value in self._cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in cas_response.attributes:
self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
clokep marked this conversation as resolved.
Show resolved Hide resolved
return

# Also need to check value
if required_value is not None:
actual_value = cas_response.attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
self._sso_handler.render_error(
request,
"unauthorised",
"You are not authorised to log in here.",
)
return

# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)

# Get the matrix ID from the CAS username.
user_id = await self._map_cas_user_to_matrix_user(
username, user_display_name, user_agent, ip_address
cas_response, user_agent, ip_address
)

if session:
Expand All @@ -225,34 +311,31 @@ async def handle_ticket(
)

async def _map_cas_user_to_matrix_user(
self,
remote_user_id: str,
display_name: Optional[str],
user_agent: str,
ip_address: str,
self, cas_response: CasResponse, user_agent: str, ip_address: str,
) -> str:
"""
Given a CAS username, retrieve the user ID for it and possibly register the user.

Args:
remote_user_id: The username from the CAS response.
display_name: The display name from the CAS response.
cas_response: The parsed CAS response.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.

Returns:
The user ID associated with this response.
"""

localpart = map_username_to_mxid_localpart(remote_user_id)
localpart = map_username_to_mxid_localpart(cas_response.username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)

displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)

# If the user does not exist, register it.
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=display_name,
default_display_name=displayname,
user_agent_ips=[(user_agent, ip_address)],
)

Expand Down