Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@

from airflow.api_fastapi.common.types import MenuItem
from airflow.cli.cli_config import CLICommand
from airflow.providers.common.compat.sdk import AirflowException, conf

try:
from airflow.providers.common.compat.sdk import AirflowException, conf
except ModuleNotFoundError:
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
Expand Down Expand Up @@ -76,6 +81,14 @@
log = logging.getLogger(__name__)

RESOURCE_ID_ATTRIBUTE_NAME = "resource_id"
TEAM_SCOPED_RESOURCES = frozenset(
{
KeycloakResource.DAG,
KeycloakResource.CONNECTION,
KeycloakResource.POOL,
KeycloakResource.VARIABLE,
}
)


class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
Expand Down Expand Up @@ -184,10 +197,7 @@ def is_authorized_configuration(
) -> bool:
config_section = details.section if details else None
return self._is_authorized(
method=method,
resource_type=KeycloakResource.CONFIGURATION,
user=user,
resource_id=config_section,
method=method, resource_type=KeycloakResource.CONFIGURATION, user=user, resource_id=config_section
)

def is_authorized_connection(
Expand All @@ -198,8 +208,13 @@ def is_authorized_connection(
details: ConnectionDetails | None = None,
) -> bool:
connection_id = details.conn_id if details else None
team_name = self._get_team_name(details)
return self._is_authorized(
method=method, resource_type=KeycloakResource.CONNECTION, user=user, resource_id=connection_id
method=method,
resource_type=KeycloakResource.CONNECTION,
user=user,
resource_id=connection_id,
team_name=team_name,
)

def is_authorized_dag(
Expand All @@ -211,12 +226,14 @@ def is_authorized_dag(
details: DagDetails | None = None,
) -> bool:
dag_id = details.id if details else None
team_name = self._get_team_name(details)
access_entity_str = access_entity.value if access_entity else None
return self._is_authorized(
method=method,
resource_type=KeycloakResource.DAG,
user=user,
resource_id=dag_id,
team_name=team_name,
attributes={"dag_entity": access_entity_str},
)

Expand Down Expand Up @@ -262,16 +279,26 @@ def is_authorized_variable(
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: VariableDetails | None = None
) -> bool:
variable_key = details.key if details else None
team_name = self._get_team_name(details)
return self._is_authorized(
method=method, resource_type=KeycloakResource.VARIABLE, user=user, resource_id=variable_key
method=method,
resource_type=KeycloakResource.VARIABLE,
user=user,
resource_id=variable_key,
team_name=team_name,
)

def is_authorized_pool(
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: PoolDetails | None = None
) -> bool:
pool_name = details.name if details else None
team_name = self._get_team_name(details)
return self._is_authorized(
method=method, resource_type=KeycloakResource.POOL, user=user, resource_id=pool_name
method=method,
resource_type=KeycloakResource.POOL,
user=user,
resource_id=pool_name,
team_name=team_name,
)

def is_authorized_view(self, *, access_view: AccessView, user: KeycloakAuthManagerUser) -> bool:
Expand Down Expand Up @@ -356,6 +383,7 @@ def _is_authorized(
resource_type: KeycloakResource,
user: KeycloakAuthManagerUser,
resource_id: str | None = None,
team_name: str | None = None,
attributes: dict[str, str | None] | None = None,
) -> bool:
client_id = conf.get(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY)
Expand All @@ -368,9 +396,19 @@ def _is_authorized(
elif method == "GET":
method = "LIST"

if (
team_name
and conf.getboolean("core", "multi_team", fallback=False)
and resource_type in TEAM_SCOPED_RESOURCES
):
resource_name = f"{resource_type.value}:{team_name}"
else:
resource_name = resource_type.value
permission = f"{resource_name}#{method}"

resp = self.http_session.post(
self._get_token_url(server_url, realm),
data=self._get_payload(client_id, f"{resource_type.value}#{method}", context_attributes),
data=self._get_payload(client_id, permission, context_attributes),
headers=self._get_headers(user.access_token),
timeout=5,
)
Expand Down Expand Up @@ -425,6 +463,12 @@ def _get_token_url(server_url, realm):
# Normalize server_url to avoid double slashes (required for Keycloak 26.4+ strict path validation).
return f"{server_url.rstrip('/')}/realms/{realm}/protocol/openid-connect/token"

@staticmethod
def _get_team_name(
details: ConnectionDetails | DagDetails | PoolDetails | VariableDetails | None,
) -> str | None:
return getattr(details, "team_name", None) if details else None

@staticmethod
def _get_payload(client_id: str, permission: str, attributes: dict[str, str] | None = None):
payload: dict[str, Any] = {
Expand Down
Loading