Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from enum import Enum
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

Expand Down Expand Up @@ -70,12 +71,36 @@
)
from airflow.cli.cli_config import CLICommand

# This cannot be in the TYPE_CHECKING block since some providers import it globally.
# TODO: Move this inside once all providers drop Airflow 2.x support.
# List of methods (or actions) a user can do against a resource
ResourceMethod = Literal["GET", "POST", "PUT", "DELETE"]
# Extends ``ResourceMethod`` to include "MENU". The method "MENU" is only supported with specific resources (menu items)
ExtendedResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]
if TYPE_CHECKING:
# For static type checking - accepts string literals
ResourceMethod = Literal["GET", "POST", "PUT", "DELETE"]
ExtendedResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]
else:
# For runtime - provides iteration and validation

class ResourceMethod(str, Enum):
"""HTTP methods (actions) a user can perform against a resource."""

GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"

def __str__(self) -> str:
return self.value

class ExtendedResourceMethod(str, Enum):
"""Extended HTTP methods including MENU for UI resource authorization."""

GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
MENU = "MENU"

def __str__(self) -> str:
return self.value


log = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseUser)
Expand Down Expand Up @@ -322,7 +347,7 @@ def is_authorized_view(
"""

@abstractmethod
def is_authorized_custom_view(self, *, method: ResourceMethod | str, resource_name: str, user: T) -> bool:
def is_authorized_custom_view(self, *, method: ResourceMethod, resource_name: str, user: T) -> bool:
"""
Return whether the user is authorized to perform a given action on a custom view.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def is_authorized_view(self, *, access_view: AccessView, user: SimpleAuthManager
return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER, user=user)

def is_authorized_custom_view(
self, *, method: ResourceMethod | str, resource_name: str, user: SimpleAuthManagerUser
self, *, method: ResourceMethod, resource_name: str, user: SimpleAuthManagerUser
):
return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER, user=user)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@

import json
import logging
from enum import Enum
from typing import get_args

from keycloak import KeycloakAdmin, KeycloakError

from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod

try:
from airflow.api_fastapi.auth.managers.base_auth_manager import ExtendedResourceMethod
except ImportError:
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod as ExtendedResourceMethod
from airflow.api_fastapi.common.types import MenuItem
from airflow.providers.common.compat.sdk import conf
from airflow.providers.keycloak.auth_manager.cli.utils import dry_run_message_wrap, dry_run_preview
Expand All @@ -41,9 +37,41 @@
from airflow.utils import cli as cli_utils
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

try:
from airflow.api_fastapi.auth.managers.base_auth_manager import ExtendedResourceMethod
except ImportError:
# Fallback for older Airflow versions where ExtendedResourceMethod doesn't exist
from airflow.api_fastapi.auth.managers.base_auth_manager import (
ResourceMethod as ExtendedResourceMethod, # type: ignore[assignment]
)

log = logging.getLogger(__name__)


def _get_resource_methods() -> list[str]:
"""
Get list of resource method values.

Provides backwards compatibility for Airflow <3.2 where ResourceMethod
was a Literal type, and Airflow >=3.2 where it's an Enum.
"""
if isinstance(ResourceMethod, type) and issubclass(ResourceMethod, Enum):
return [method.value for method in ResourceMethod]
return list(get_args(ResourceMethod))


def _get_extended_resource_methods() -> list[str]:
"""
Get list of extended resource method values.

Provides backwards compatibility for Airflow <3.2 where ExtendedResourceMethod
was a Literal type, and Airflow >=3.2 where it's an Enum.
"""
if isinstance(ExtendedResourceMethod, type) and issubclass(ExtendedResourceMethod, Enum):
return [method.value for method in ExtendedResourceMethod]
return list(get_args(ExtendedResourceMethod))


@cli_utils.action_cli
@providers_configuration_loaded
@dry_run_message_wrap
Expand Down Expand Up @@ -119,7 +147,7 @@ def _get_client_uuid(args):

def _get_scopes_to_create() -> list[dict]:
"""Get the list of scopes to be created."""
scopes = [{"name": method} for method in get_args(ResourceMethod)]
scopes = [{"name": method} for method in _get_resource_methods()]
scopes.extend([{"name": "MENU"}, {"name": "LIST"}])
return scopes

Expand Down Expand Up @@ -231,7 +259,7 @@ def _get_permissions_to_create(client: KeycloakAdmin, client_uuid: str) -> list[
{
"name": "Admin",
"type": "scope-based",
"scope_names": list(get_args(ExtendedResourceMethod)) + ["LIST"],
"scope_names": _get_extended_resource_methods() + ["LIST"],
},
{
"name": "User",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
from __future__ import annotations

import importlib
from typing import get_args
from unittest.mock import Mock, call, patch

import pytest

from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
from airflow.api_fastapi.common.types import MenuItem
from airflow.cli import cli_parser
from airflow.providers.keycloak.auth_manager.cli.commands import (
_get_resource_methods,
create_all_command,
create_permissions_command,
create_resources_command,
Expand Down Expand Up @@ -82,7 +81,7 @@ def test_create_scopes(self, mock_get_client):
create_scopes_command(self.arg_parser.parse_args(params))

client.get_clients.assert_called_once_with()
scopes = [{"name": method} for method in get_args(ResourceMethod)]
scopes = [{"name": method} for method in _get_resource_methods()]
calls = [call(client_id="test-id", payload=scope) for scope in scopes]
client.create_client_authz_scopes.assert_has_calls(calls)

Expand Down
Loading