Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions azure_auth/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class AuthHandler:
Class to interface with `msal` package and execute authentication process.
"""

GROUPS_UPDATED = False

def __init__(self, request: HttpRequest):
"""

Expand Down Expand Up @@ -131,28 +133,49 @@ def authenticate(self, token: dict) -> AbstractBaseUser:
# in AZURE_AUTH.GROUP_ATTRIBUTE
def sync_groups(self, user, token):
role_mappings = settings.AZURE_AUTH.get("ROLES")
if not role_mappings:
# No role mappings defined, nothing to do
return user
self.initialize_groups()

groups_attr = settings.AZURE_AUTH.get("GROUP_ATTRIBUTE", "roles")
azure_token_roles = token.get("id_token_claims", {}).get(groups_attr, None)
if role_mappings: # pragma: no branch
for role, group_names in role_mappings.items():
if not isinstance(group_names, list):
group_names = [group_names]
for group_name in group_names:
if not group_name:
continue # Skip empty group names
# all groups are created by default if they not exist
django_group = Group.objects.get_or_create(name=group_name)[0]

if azure_token_roles and role in azure_token_roles:
# Add user with permissions to the corresponding django group
user.groups.add(django_group)
else:
# No permission so check if user is in group and remove
if user.groups.filter(name=group_name).exists():
user.groups.remove(django_group)

token_groups = set()
all_groups = set()
for role_id, name_or_names in role_mappings.items():
if not isinstance(name_or_names, list):
name_or_names = [str(name_or_names)]
all_groups.update(name_or_names)
if role_id in azure_token_roles:
token_groups.update(name_or_names)
current_groups = list(user.groups.values_list("name", flat=True))

if to_add := [item for item in token_groups if item not in current_groups]:
user.groups.add(*Group.objects.filter(name__in=to_add))

if to_remove := [
item
for item in current_groups
if item in all_groups and item not in token_groups
]:
user.groups.remove(*Group.objects.filter(name__in=to_remove))
return user

def initialize_groups(self):
if not AuthHandler.GROUPS_UPDATED:
role_mappings: dict = settings.AZURE_AUTH.get("ROLES")
all_groups = set()
for group_names in role_mappings.values():
if not isinstance(group_names, list):
group_names = [group_names]
all_groups.update(group_names)
if all_groups:
for group_name in all_groups:
if group_name:
Group.objects.get_or_create(name=group_name)
AuthHandler.GROUPS_UPDATED = True

def get_logout_uri(self) -> str:
"""
Forms the URI to log the user out in the Active Directory app and
Expand Down
6 changes: 6 additions & 0 deletions azure_auth/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from django.contrib.auth import get_user_model
from mixer.backend.django import mixer
from azure_auth.handlers import AuthHandler

UserModel = get_user_model()

Expand All @@ -19,6 +20,11 @@ def user(request):
return _user


@pytest.fixture(scope="function", autouse=True)
def clean_groups():
AuthHandler.GROUPS_UPDATED = False


@pytest.fixture(scope="function")
def auth_flow(request):
_auth_flow = {
Expand Down
21 changes: 21 additions & 0 deletions azure_auth/tests/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,24 @@ def test_empty_group_list_mapping(self):
handler = self._build_auth_handler()
handler.sync_groups(self.user, self.token)
self.assertEqual(self.user.groups.count(), 0)

@override_settings(
AZURE_AUTH=ChainMap(
{
"ROLES": {
"95170e67-2bbf-4e3e-a4d7-e7e5829fe7a7": [
"GroupName1",
"GroupName2",
"GroupName3",
],
"cfa8556d-dd93-420c-abd2-477ba336d2d6": "GroupName1",
"e236600b-7062-4bff-8ccf-12d8fbd80e5f": "GroupName1",
}
},
settings.AZURE_AUTH,
)
)
def test_multiple_roles_to_one_group(self):
handler = self._build_auth_handler()
handler.sync_groups(self.user, self.token)
self.assertEqual(self.user.groups.count(), 3)