diff --git a/tiled/_tests/test_authenticators.py b/tiled/_tests/test_authenticators.py index af2f98568..ea00100f6 100644 --- a/tiled/_tests/test_authenticators.py +++ b/tiled/_tests/test_authenticators.py @@ -5,8 +5,21 @@ from ..authenticators import LDAPAuthenticator +# fmt: off +@pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ + ("localhost", 1389), + ("localhost:1389", 904), # Random port, ignored + ("localhost:1389", None), + ("127.0.0.1", 1389), + ("127.0.0.1:1389", 904), + (["localhost"], 1389), + (["localhost", "127.0.0.1"], 1389), + (["localhost", "127.0.0.1:1389"], 1389), + (["localhost:1389", "127.0.0.1:1389"], None), +]) +# fmt: on @pytest.mark.parametrize("use_tls,use_ssl", [(False, False)]) -def test_LDAPAuthenticator_01(use_tls, use_ssl): +def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server_port): """ Basic test for ``LDAPAuthenticator``. @@ -15,8 +28,8 @@ def test_LDAPAuthenticator_01(use_tls, use_ssl): """ pytest.importorskip("ldap3") authenticator = LDAPAuthenticator( - "localhost", - 1389, + ldap_server_address, + ldap_server_port, bind_dn_template="cn={username},ou=users,dc=example,dc=org", use_tls=use_tls, use_ssl=use_ssl, diff --git a/tiled/authenticators.py b/tiled/authenticators.py index 3151f44dd..4233e6299 100644 --- a/tiled/authenticators.py +++ b/tiled/authenticators.py @@ -1,7 +1,9 @@ import asyncio +import functools import logging import re import secrets +from collections.abc import Iterable from fastapi import APIRouter, Request from jose import JWTError, jwk, jwt @@ -347,8 +349,11 @@ class LDAPAuthenticator: Parameters ---------- - server_address: str - Address of the LDAP server to contact. + server_address: str or list(str) + Address(es) of the LDAP server(s) to contact. A string value may represent a single + server, a list of strings may represent one or more servers. If a server address + includes port, then the value of ``server_port`` is ignored, otherwise ``server_port`` + or the default port is used to access the server. Could be an IP address or hostname. server_port: int or None @@ -363,6 +368,14 @@ class LDAPAuthenticator: Enable/disable TLS if ``use_ssl`` is False. By default TLS is enabled. It should not be disabled in production systems. + connect_timeout: float + Timeout used for connecting to the LDAP server. Default: 5. + + receive_timeout: float + Timeout used for communication with the LDAP server, e.g. this timeout is used to wait for + completion of 2FA. For smooth operation it should probably exceed timeout set at LDAP server. + Default: 60. + bind_dn_template: list or str Template from which to construct the full dn when authenticating to LDAP. ``{username}`` is replaced @@ -519,6 +532,8 @@ def __init__( *, use_ssl=False, use_tls=True, + connect_timeout=5, + receive_timeout=60, bind_dn_template=None, allowed_groups=None, valid_username_regex=r"^[a-z][.a-z0-9_-]*$", @@ -535,12 +550,10 @@ def __init__( auth_state_attributes=None, use_lookup_dn_username=True, ): - if not modules_available("ldap3"): - raise ModuleNotFoundError( - "This LDAPAuthenticator requires the module 'ldap3' to be installed." - ) self.use_ssl = use_ssl self.use_tls = use_tls + self.connect_timeout = connect_timeout + self.receive_timeout = receive_timeout self.bind_dn_template = bind_dn_template self.allowed_groups = allowed_groups self.valid_username_regex = valid_username_regex @@ -559,7 +572,21 @@ def __init__( ) self.use_lookup_dn_username = use_lookup_dn_username - self.server_address = server_address + if isinstance(server_address, str): + server_address_list = [server_address] + elif isinstance(server_address, Iterable): + server_address_list = list(server_address) + else: + raise TypeError( + f"Unsupported type of `server_address` (list): server_address={server_address} " + f"type(server_address)={type(server_address)}" + ) + if not server_address_list: + raise ValueError( + "No servers are specified: 'server_address' is an empty list" + ) + + self.server_address_list = server_address_list self.server_port = ( server_port if server_port is not None else self._server_port_default() ) @@ -571,8 +598,8 @@ def _server_port_default(self): return 389 # default plaintext port for LDAP async def resolve_username(self, username_supplied_by_user): + import ldap3 - import ldap3.utils.conv search_dn = self.lookup_dn_search_user if self.escape_userdn: @@ -604,12 +631,16 @@ async def resolve_username(self, username_supplied_by_user): attributes=self.user_attribute, ) ) - conn.search( + + search_func = functools.partial( + conn.search, search_base=self.user_search_base, search_scope=ldap3.SUBTREE, search_filter=search_filter, attributes=[self.lookup_dn_user_dn_attribute], ) + await asyncio.get_running_loop().run_in_executor(None, search_func) + response = conn.response if len(response) == 0 or "attributes" not in response[0].keys(): msg = ( @@ -649,33 +680,66 @@ async def resolve_username(self, username_supplied_by_user): return (user_dn, response[0]["dn"]) def get_connection(self, userdn, password): + import ldap3 - server = ldap3.Server( - self.server_address, port=self.server_port, use_ssl=self.use_ssl - ) + # NOTE: setting 'acitve=False' essentially disables exclusion of inactive servers from the pool. + # It probably does not matter if the pool contains only one server, but it could have implications + # when there are multiple servers in the pool. It is not clear what those implications are. + # But using the default 'activate=True' results in the thread being blocked indefinitely + # at the step of creating 'ldap3.Connection' regardless of timeouts in case all the servers are + # inactive (e.g. the pool has one server and it is unaccessible), which is unacceptable. + # Further investigation may be needed in the future. + server_pool = ldap3.ServerPool(None, ldap3.RANDOM, active=False) + for address in self.server_address_list: + if re.search(r".+:\d+", address): + # Port is found in the address + address_split = address.split(":") + server_addr = ":".join(address_split[:-1]) + server_port = int(address_split[-1]) + else: + # Use the default port + server_addr = address + server_port = self.server_port + + server = ldap3.Server( + server_addr, + port=server_port, + use_ssl=self.use_ssl, + connect_timeout=self.connect_timeout, + ) + server_pool.add(server) + auto_bind_no_ssl = ( ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS ) auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl conn = ldap3.Connection( - server, user=userdn, password=password, auto_bind=auto_bind + server_pool, + user=userdn, + password=password, + auto_bind=auto_bind, + receive_timeout=self.receive_timeout, ) return conn - def get_user_attributes(self, conn, userdn): + async def get_user_attributes(self, conn, userdn): attrs = {} if self.auth_state_attributes: - found = conn.search( - userdn, "(objectClass=*)", attributes=self.auth_state_attributes + search_func = functools.partial( + conn.search, + userdn, + "(objectClass=*)", + attributes=self.auth_state_attributes, ) + found = await asyncio.get_running_loop().run_in_executor(None, search_func) if found: attrs = conn.entries[0].entry_attributes_as_dict return attrs async def authenticate(self, username: str, password: str): + import ldap3 - import ldap3.utils.conv username_saved = username # Save the user name passed as a parameter @@ -737,13 +801,13 @@ async def authenticate(self, username: str, password: str): exc_msg=exc.args[0] if exc.args else "", ) else: - is_bound = ( - True - if conn.bound - else await asyncio.get_running_loop().run_in_executor( + if conn.bound: + is_bound = True + else: + is_bound = await asyncio.get_running_loop().run_in_executor( None, conn.bind ) - ) + msg = msg.format(username=username, userdn=userdn, is_bound=is_bound) logger.debug(msg) if is_bound: @@ -758,12 +822,16 @@ async def authenticate(self, username: str, password: str): search_filter = self.search_filter.format( userattr=self.user_attribute, username=username ) - conn.search( + + search_func = functools.partial( + conn.search, search_base=self.user_search_base, search_scope=ldap3.SUBTREE, search_filter=search_filter, attributes=self.attributes, ) + await asyncio.get_running_loop().run_in_executor(None, search_func) + n_users = len(conn.response) if n_users == 0: msg = "User with '{userattr}={username}' not found in directory" @@ -796,14 +864,20 @@ async def authenticate(self, username: str, password: str): ) group_filter = group_filter.format(userdn=userdn, uid=username) group_attributes = ["member", "uniqueMember", "memberUid"] - found = conn.search( + + search_func = functools.partial( + conn.search, group, search_scope=ldap3.BASE, search_filter=group_filter, attributes=group_attributes, ) + found = await asyncio.get_running_loop().run_in_executor( + None, search_func + ) if found: break + if not found: # If we reach here, then none of the groups matched msg = "username:{username} User not in any of the allowed groups" @@ -813,7 +887,7 @@ async def authenticate(self, username: str, password: str): if not self.use_lookup_dn_username: username = username_saved - user_info = self.get_user_attributes(conn, userdn) + user_info = await self.get_user_attributes(conn, userdn) if user_info: logger.debug("username:%s attributes:%s", username, user_info) return {"name": username, "auth_state": user_info}