Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions app/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
global_admin = "globalAdmin"
admin_project = "adminProject"
practice_lead_project = "practiceLeadProject"
member_project = "memberProject"
self_value = "self"
ADMIN_GLOBAL = "adminGlobal"
ADMIN_PROJECT = "adminProject"
PRACTICE_LEAD_PROJECT = "practiceLeadProject"
MEMBER_PROJECT = "memberProject"
FIELD_PERMISSIONS_CSV = "core/api/field_permissions.csv"
177 changes: 177 additions & 0 deletions app/core/api/access_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import csv
from pathlib import Path
from typing import Any

from constants import ADMIN_GLOBAL # Assuming you have this constant
from constants import FIELD_PERMISSIONS_CSV
from core.models import PermissionType
from core.models import UserPermission


class AccessControl:
"""A collection of static methods for validating user permissions."""

_rank_dict_cache: dict[str, int] | None = None # class-level cache
_csv_field_permissions_cache: list[dict[str, Any]] | None = None

@staticmethod
def is_admin(user) -> bool:
"""Check if a user assigned "adminGlobal" permission."""
permission_type = PermissionType.objects.filter(name=ADMIN_GLOBAL).first()
# return True
return UserPermission.objects.filter(
permission_type=permission_type, user=user
).exists()

@classmethod
def _get_rank_dict(cls) -> dict[str, int]:
"""Return a dictionary mapping permission names to their ranks.
Example: {"adminGlobal": 1, "adminProject": 2, "practiceLeadProject": 3, "memberProject": 4}.
Used in algorithm to determine most privileged permission type between two users. The higher the rank,
the more privileged the permission.
"""
if cls._rank_dict_cache is None:
permissions = PermissionType.objects.values("name", "rank")
cls._rank_dict_cache = {perm["name"]: perm["rank"] for perm in permissions}
return cls._rank_dict_cache

@classmethod
def _get_csv_field_permissions(cls) -> list[dict[str, Any]]:
"""Read the field permissions from a CSV file.

Caches the result so the CSV is read only once.
"""
if cls._csv_field_permissions_cache is None:
file_path = Path(FIELD_PERMISSIONS_CSV)
with file_path.open() as file:
reader = csv.DictReader(file)
cls._csv_field_permissions_cache = list(reader)
return cls._csv_field_permissions_cache

@classmethod
def get_permitted_fields(
cls, operation: str, permission_type: str, table_name: str
) -> list[str]:
"""
Return the list of field names accessible for a user with the given permission type
for a given operation.

Parameters:
operation (str): The type of operation. (e.g., "get", "post", "patch").
permission_type (str): The permission type of the requesting user
(e.g., "adminGlobal", "adminProject", etc.).
table_name (str): The name of the table/model
(e.g., "User", "Project").

Returns:
list[str]: A list of field names that the user with the given
permission_type can access for the specified operation on the
specified table.

Example:
>>> get_fields("get", "adminProject", "User")
["first_name", "last_name"]
"""
if not permission_type: # Early exit guard clause
return []

valid_fields = set()

for field_permission in cls._get_csv_field_permissions():
if cls.has_field_permission(
operation=operation,
requester_permission_type=permission_type,
table_name=table_name,
field=field_permission,
):
valid_fields.add(field_permission["field_name"])

return list(valid_fields)

@classmethod
def get_highest_user_perm_type(cls, requesting_user) -> str:
"""Return the most privileged permission type of a user."""

permissions = UserPermission.objects.filter(
user=requesting_user, project__name=None
).values("permission_type__name", "permission_type__rank")

if not permissions:
return ""

max_permission = max(permissions, key=lambda p: p["permission_type__rank"])
return max_permission["permission_type__name"]

@classmethod
def get_highest_shared_project_perm_type(
cls, requesting_user, response_related_user
) -> str:
"""Return the most privileged permission type between users."""
if cls.is_admin(requesting_user):
return ADMIN_GLOBAL

target_projects = UserPermission.objects.filter(
user=response_related_user
).values_list("project__name", flat=True)
target_projects = UserPermission.objects.filter(
user=response_related_user
).values_list("project__name", flat=True)

permissions = UserPermission.objects.filter(
user=requesting_user, project__name__in=target_projects
).values("permission_type__name", "permission_type__rank")
if not permissions:
return ""

max_permission = max(permissions, key=lambda p: p["permission_type__rank"])
return max_permission["permission_type__name"]

@classmethod
def has_field_permission(
cls,
operation: str,
requester_permission_type: str,
table_name: str,
field: dict,
) -> bool:
"""
Determine whether a user with a given permission type has access to a field
for a specific operation on a specific table.

Parameters:
operation (str): The type of operation ("get", "post", or "patch").
permission_type (str): The user's permission type
(e.g., "adminGlobal", "adminProject").
table_name (str): The name of the table/model to check (e.g., "User").
field (dict): A dictionary describing the field, including at least:
- "field_name"
- "table_name"
- operation-specific permission values
(e.g., {"get": "adminProject"}).

Returns:
bool: True if the permission type allows access to the field for the operation,
False otherwise.

Example:
>>> field_info = {
... "field_name": "email",
... "table_name": "User",
... "get": "adminProject"
... }
>>> has_field_permission("get", "adminProject", "User", field_info)
True
"""
operation_permission_type = field.get(operation, "")
if not operation_permission_type or field.get("table_name") != table_name:
return False

rank_dict = cls._get_rank_dict()
if (
requester_permission_type not in rank_dict
or operation_permission_type not in rank_dict
):
return False
return (
rank_dict[requester_permission_type] <= rank_dict[operation_permission_type]
)
29 changes: 29 additions & 0 deletions app/core/api/field_permissions.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
table_name,field_name,get,patch,post
User,uuid,memberProject,,
User,username,memberProject,,adminGlobal
User,is_active,,,
User,is_superuser,,,
User,is_staff,,,
User,first_name,memberProject,adminBrigade,adminGlobal
User,last_name,memberProject,adminBrigade,adminGlobal
User,email,memberProject,adminBrigade,adminGlobal
User,email_gmail,practiceLeadProject,adminBrigade,adminGlobal
User,email_preferred,practiceLeadProject,adminBrigade,adminGlobal
User,email_cognito,adminBrigade,adminBrigade,adminGlobal
User,created_at,adminProject,,,
User,job_title_current_intake,adminBrigade,adminBrigade,adminGlobal
User,job_title_target_intake,adminBrigade,adminBrigade,adminGlobal
User,current_skills,adminBrigade,adminBrigade,adminGlobal
User,target_skills,adminBrigade,adminBrigade,adminGlobal
User,linkedin_account,memberProject,adminBrigade,adminGlobal
User,github_handle,memberProject,adminBrigade,adminGlobal
User,phone,practiceLeadProject,adminBrigade,adminGlobal
User,texting_ok,practiceLeadProject,adminBrigade,adminGlobal
User,slack_id,memberProject,adminBrigade,adminGlobal
User,time_zone,memberProject,adminBrigade,adminGlobal
User,password,,adminBrigade,adminGlobal
User,last_login,adminProject,adminProject,,
User,practice_area_primary,adminProject,adminGlobal,adminGlobal
User,user_status,memberProject,adminBrigade,adminGlobal
User,updated_at,adminProject,,
User,referrer,memberProject,adminGlobal,adminGlobal
16 changes: 16 additions & 0 deletions app/core/api/has_user_permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from rest_framework.permissions import BasePermission

from .validate_request import validate_patch_fields
from .validate_request import validate_post_fields


class HasUserPermission(BasePermission):
def has_permission(self, request, view):
if request.method == "POST":
validate_post_fields(request=request, view=view)
return True # Default to allow the request

def has_object_permission(self, request, view, obj):
if request.method == "PATCH":
validate_patch_fields(obj=obj, request=request)
return True
72 changes: 72 additions & 0 deletions app/core/api/validate_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from rest_framework.exceptions import PermissionDenied
from rest_framework.exceptions import ValidationError

from core.models import User

from .access_control import AccessControl


def validate_post_fields(view, request):
# todo
serializer_class = view.serializer_class
table_name = serializer_class.Meta.model.__name__
permitted_fields = _get_permitted_fields_for_post_request(
request=request, table_name=table_name
)
_validate_request_fields_permitted(request, permitted_fields)


def get_fields_for_patch_request(request, table_name, response_related_user):
requesting_user = request.user
requesting_user = request.user
most_privileged_perm_type = AccessControl.get_highest_shared_project_perm_type(
requesting_user, response_related_user
)
fields = AccessControl.get_permitted_fields(
operation="patch",
table_name=table_name,
permission_type=most_privileged_perm_type,
)
return fields


def _get_permitted_fields_for_post_request(request, table_name):
highest_perm_type = AccessControl.get_highest_user_perm_type(request.user)
fields = AccessControl.get_permitted_fields(
operation="post",
table_name=table_name,
permission_type=highest_perm_type,
)
return fields


def _get_related_user_from_obj(obj):
if hasattr(obj, "user"):
return obj.user
elif isinstance(obj, User):
return obj
else:
raise ValueError("Cannot determine related user from the given object.")


def validate_patch_fields(request, obj):
table_name = obj.__class__.__name__
response_related_user = _get_related_user_from_obj(obj)
valid_fields = get_fields_for_patch_request(
table_name=table_name,
request=request,
response_related_user=response_related_user,
)
_validate_request_fields_permitted(request, valid_fields)


# @staticmethod
def _validate_request_fields_permitted(request, valid_fields) -> None:
"""Ensure the requesting user can patch the provided fields."""
request_fields_set = set(request.data)
permitted_fields_set = set(valid_fields)
notpermitted_fields = request_fields_set - permitted_fields_set
if not permitted_fields_set:
raise PermissionDenied("You do not have privileges ")
elif notpermitted_fields:
raise ValidationError(f"Invalid fields: {', '.join(notpermitted_fields)}")
4 changes: 3 additions & 1 deletion app/core/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from rest_framework.permissions import IsAuthenticated
from rest_framework.permissions import IsAuthenticatedOrReadOnly

from core.api.has_user_permissions import HasUserPermission

from ..models import Affiliate
from ..models import Affiliation
from ..models import CheckType
Expand Down Expand Up @@ -123,7 +125,7 @@ def get(self, request, *args, **kwargs):
partial_update=extend_schema(description="Partially update the given user"),
)
class UserViewSet(viewsets.ModelViewSet):
permission_classes = [IsAuthenticated]
permission_classes = [IsAuthenticated, HasUserPermission]
serializer_class = UserSerializer
lookup_field = "uuid"

Expand Down
25 changes: 14 additions & 11 deletions app/core/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
from django.contrib.auth import get_user_model
from rest_framework.test import APIClient

from constants import admin_project
from constants import practice_lead_project
from constants import ADMIN_PROJECT
from constants import PRACTICE_LEAD_PROJECT
from test_data.utils.seed_constants import garry_name
from test_data.utils.seed_user import SeedUser

from ..models import Affiliate
from ..models import Affiliation
Expand All @@ -28,10 +31,15 @@
from ..models import StackElement
from ..models import StackElementType
from ..models import UrlType
from ..models import User
from ..models import UserPermission
from ..models import UserStatusType

collect_ignore = ["utils"]

# conftest.py

User = get_user_model()


@pytest.fixture
def user_superuser_admin():
Expand Down Expand Up @@ -71,7 +79,7 @@ def user_permission_admin_project():
username="TestUser Admin Project", email="TestUserAdminProject@example.com"
)
project = Project.objects.create(name="Test Project Admin Project")
permission_type = PermissionType.objects.filter(name=admin_project).first()
permission_type = PermissionType.objects.filter(name=ADMIN_PROJECT).first()
user_permission = UserPermission.objects.create(
user=user,
permission_type=permission_type,
Expand All @@ -87,7 +95,7 @@ def user_permission_practice_lead_project():
username="TestUser Practie Lead Project",
email="TestUserPracticeLeadProject@example.com",
)
permission_type = PermissionType.objects.filter(name=practice_lead_project).first()
permission_type = PermissionType.objects.filter(name=PRACTICE_LEAD_PROJECT).first()
project = Project.objects.create(name="Test Project Admin Project")
practice_area = PracticeArea.objects.first()
user_permission = UserPermission.objects.create(
Expand Down Expand Up @@ -120,12 +128,7 @@ def user2(django_user_model):

@pytest.fixture
def admin(django_user_model):
return django_user_model.objects.create_user(
is_staff=True,
username="TestAdminUser",
email="testadmin@email.com",
password="testadmin",
)
return SeedUser.get_user(garry_name)


@pytest.fixture
Expand Down
Loading