Skip to content

Commit

Permalink
Recent changes to the implementation of LDAPAuthenticator authenticat…
Browse files Browse the repository at this point in the history
…or in 'bluesky-httpserver' (bluesky#308)

* ENH: copy changes to LDAPAuthenticator from bluesky-httpserver

* TST: unit test for LDAPAuthenticator

* ENH: code formatting

* Satisfy isort.

Co-authored-by: Dan Allan <dallan@bnl.gov>
  • Loading branch information
dmgav and danielballan committed Sep 7, 2022
1 parent 35fe710 commit 2abe50c
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 28 deletions.
19 changes: 16 additions & 3 deletions tiled/_tests/test_authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,21 @@
from ..authenticators import LDAPAuthenticator


# fmt: off
@pytest.mark.parametrize("ldap_server_address, ldap_server_port", [
("localhost", 1389),
("localhost:1389", 904), # Random port, ignored
("localhost:1389", None),
("127.0.0.1", 1389),
("127.0.0.1:1389", 904),
(["localhost"], 1389),
(["localhost", "127.0.0.1"], 1389),
(["localhost", "127.0.0.1:1389"], 1389),
(["localhost:1389", "127.0.0.1:1389"], None),
])
# fmt: on
@pytest.mark.parametrize("use_tls,use_ssl", [(False, False)])
def test_LDAPAuthenticator_01(use_tls, use_ssl):
def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server_port):
"""
Basic test for ``LDAPAuthenticator``.
Expand All @@ -15,8 +28,8 @@ def test_LDAPAuthenticator_01(use_tls, use_ssl):
"""
pytest.importorskip("ldap3")
authenticator = LDAPAuthenticator(
"localhost",
1389,
ldap_server_address,
ldap_server_port,
bind_dn_template="cn={username},ou=users,dc=example,dc=org",
use_tls=use_tls,
use_ssl=use_ssl,
Expand Down
124 changes: 99 additions & 25 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import functools
import logging
import re
import secrets
from collections.abc import Iterable

from fastapi import APIRouter, Request
from jose import JWTError, jwk, jwt
Expand Down Expand Up @@ -347,8 +349,11 @@ class LDAPAuthenticator:
Parameters
----------
server_address: str
Address of the LDAP server to contact.
server_address: str or list(str)
Address(es) of the LDAP server(s) to contact. A string value may represent a single
server, a list of strings may represent one or more servers. If a server address
includes port, then the value of ``server_port`` is ignored, otherwise ``server_port``
or the default port is used to access the server.
Could be an IP address or hostname.
server_port: int or None
Expand All @@ -363,6 +368,14 @@ class LDAPAuthenticator:
Enable/disable TLS if ``use_ssl`` is False. By default TLS is enabled. It should not be disabled
in production systems.
connect_timeout: float
Timeout used for connecting to the LDAP server. Default: 5.
receive_timeout: float
Timeout used for communication with the LDAP server, e.g. this timeout is used to wait for
completion of 2FA. For smooth operation it should probably exceed timeout set at LDAP server.
Default: 60.
bind_dn_template: list or str
Template from which to construct the full dn
when authenticating to LDAP. ``{username}`` is replaced
Expand Down Expand Up @@ -519,6 +532,8 @@ def __init__(
*,
use_ssl=False,
use_tls=True,
connect_timeout=5,
receive_timeout=60,
bind_dn_template=None,
allowed_groups=None,
valid_username_regex=r"^[a-z][.a-z0-9_-]*$",
Expand All @@ -535,12 +550,10 @@ def __init__(
auth_state_attributes=None,
use_lookup_dn_username=True,
):
if not modules_available("ldap3"):
raise ModuleNotFoundError(
"This LDAPAuthenticator requires the module 'ldap3' to be installed."
)
self.use_ssl = use_ssl
self.use_tls = use_tls
self.connect_timeout = connect_timeout
self.receive_timeout = receive_timeout
self.bind_dn_template = bind_dn_template
self.allowed_groups = allowed_groups
self.valid_username_regex = valid_username_regex
Expand All @@ -559,7 +572,21 @@ def __init__(
)
self.use_lookup_dn_username = use_lookup_dn_username

self.server_address = server_address
if isinstance(server_address, str):
server_address_list = [server_address]
elif isinstance(server_address, Iterable):
server_address_list = list(server_address)
else:
raise TypeError(
f"Unsupported type of `server_address` (list): server_address={server_address} "
f"type(server_address)={type(server_address)}"
)
if not server_address_list:
raise ValueError(
"No servers are specified: 'server_address' is an empty list"
)

self.server_address_list = server_address_list
self.server_port = (
server_port if server_port is not None else self._server_port_default()
)
Expand All @@ -571,8 +598,8 @@ def _server_port_default(self):
return 389 # default plaintext port for LDAP

async def resolve_username(self, username_supplied_by_user):

import ldap3
import ldap3.utils.conv

search_dn = self.lookup_dn_search_user
if self.escape_userdn:
Expand Down Expand Up @@ -604,12 +631,16 @@ async def resolve_username(self, username_supplied_by_user):
attributes=self.user_attribute,
)
)
conn.search(

search_func = functools.partial(
conn.search,
search_base=self.user_search_base,
search_scope=ldap3.SUBTREE,
search_filter=search_filter,
attributes=[self.lookup_dn_user_dn_attribute],
)
await asyncio.get_running_loop().run_in_executor(None, search_func)

response = conn.response
if len(response) == 0 or "attributes" not in response[0].keys():
msg = (
Expand Down Expand Up @@ -649,33 +680,66 @@ async def resolve_username(self, username_supplied_by_user):
return (user_dn, response[0]["dn"])

def get_connection(self, userdn, password):

import ldap3

server = ldap3.Server(
self.server_address, port=self.server_port, use_ssl=self.use_ssl
)
# NOTE: setting 'acitve=False' essentially disables exclusion of inactive servers from the pool.
# It probably does not matter if the pool contains only one server, but it could have implications
# when there are multiple servers in the pool. It is not clear what those implications are.
# But using the default 'activate=True' results in the thread being blocked indefinitely
# at the step of creating 'ldap3.Connection' regardless of timeouts in case all the servers are
# inactive (e.g. the pool has one server and it is unaccessible), which is unacceptable.
# Further investigation may be needed in the future.
server_pool = ldap3.ServerPool(None, ldap3.RANDOM, active=False)
for address in self.server_address_list:
if re.search(r".+:\d+", address):
# Port is found in the address
address_split = address.split(":")
server_addr = ":".join(address_split[:-1])
server_port = int(address_split[-1])
else:
# Use the default port
server_addr = address
server_port = self.server_port

server = ldap3.Server(
server_addr,
port=server_port,
use_ssl=self.use_ssl,
connect_timeout=self.connect_timeout,
)
server_pool.add(server)

auto_bind_no_ssl = (
ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS
)
auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl
conn = ldap3.Connection(
server, user=userdn, password=password, auto_bind=auto_bind
server_pool,
user=userdn,
password=password,
auto_bind=auto_bind,
receive_timeout=self.receive_timeout,
)
return conn

def get_user_attributes(self, conn, userdn):
async def get_user_attributes(self, conn, userdn):
attrs = {}
if self.auth_state_attributes:
found = conn.search(
userdn, "(objectClass=*)", attributes=self.auth_state_attributes
search_func = functools.partial(
conn.search,
userdn,
"(objectClass=*)",
attributes=self.auth_state_attributes,
)
found = await asyncio.get_running_loop().run_in_executor(None, search_func)
if found:
attrs = conn.entries[0].entry_attributes_as_dict
return attrs

async def authenticate(self, username: str, password: str):

import ldap3
import ldap3.utils.conv

username_saved = username # Save the user name passed as a parameter

Expand Down Expand Up @@ -737,13 +801,13 @@ async def authenticate(self, username: str, password: str):
exc_msg=exc.args[0] if exc.args else "",
)
else:
is_bound = (
True
if conn.bound
else await asyncio.get_running_loop().run_in_executor(
if conn.bound:
is_bound = True
else:
is_bound = await asyncio.get_running_loop().run_in_executor(
None, conn.bind
)
)

msg = msg.format(username=username, userdn=userdn, is_bound=is_bound)
logger.debug(msg)
if is_bound:
Expand All @@ -758,12 +822,16 @@ async def authenticate(self, username: str, password: str):
search_filter = self.search_filter.format(
userattr=self.user_attribute, username=username
)
conn.search(

search_func = functools.partial(
conn.search,
search_base=self.user_search_base,
search_scope=ldap3.SUBTREE,
search_filter=search_filter,
attributes=self.attributes,
)
await asyncio.get_running_loop().run_in_executor(None, search_func)

n_users = len(conn.response)
if n_users == 0:
msg = "User with '{userattr}={username}' not found in directory"
Expand Down Expand Up @@ -796,14 +864,20 @@ async def authenticate(self, username: str, password: str):
)
group_filter = group_filter.format(userdn=userdn, uid=username)
group_attributes = ["member", "uniqueMember", "memberUid"]
found = conn.search(

search_func = functools.partial(
conn.search,
group,
search_scope=ldap3.BASE,
search_filter=group_filter,
attributes=group_attributes,
)
found = await asyncio.get_running_loop().run_in_executor(
None, search_func
)
if found:
break

if not found:
# If we reach here, then none of the groups matched
msg = "username:{username} User not in any of the allowed groups"
Expand All @@ -813,7 +887,7 @@ async def authenticate(self, username: str, password: str):
if not self.use_lookup_dn_username:
username = username_saved

user_info = self.get_user_attributes(conn, userdn)
user_info = await self.get_user_attributes(conn, userdn)
if user_info:
logger.debug("username:%s attributes:%s", username, user_info)
return {"name": username, "auth_state": user_info}
Expand Down

0 comments on commit 2abe50c

Please sign in to comment.