Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
if TYPE_CHECKING:
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager

# Define the path in which the potential auth manager fastapi is mounted
AUTH_MANAGER_FASTAPI_APP_PREFIX = "/auth"

log = logging.getLogger(__name__)

app: FastAPI | None = None
Expand Down Expand Up @@ -141,7 +144,7 @@ def init_auth_manager(app: FastAPI | None = None) -> BaseAuthManager:
am.init()

if app and (auth_manager_fastapi_app := am.get_fastapi_app()):
app.mount("/auth", auth_manager_fastapi_app)
app.mount(AUTH_MANAGER_FASTAPI_APP_PREFIX, auth_manager_fastapi_app)
app.state.auth_manager = am

return am
Expand Down
9 changes: 9 additions & 0 deletions airflow/api_fastapi/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
from airflow.api_fastapi.auth.managers.models.resource_details import DagDetails
from airflow.api_fastapi.common.types import MenuItem
from airflow.configuration import conf
from airflow.models import DagModel
from airflow.typing_compat import Literal
Expand Down Expand Up @@ -411,6 +412,14 @@ def get_fastapi_app(self) -> FastAPI | None:
"""
return None

def get_menu_items(self, *, user: T) -> list[MenuItem]:
"""
Provide additional links to be added to the menu.

:param user: the user
"""
return []

@staticmethod
def _get_token_signer(
expiration_time_in_seconds: int = conf.getint("api", "auth_jwt_expiration_time"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from starlette.templating import Jinja2Templates
from termcolor import colored

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.configuration import AIRFLOW_HOME, conf
Expand Down Expand Up @@ -131,9 +132,9 @@ def get_url_login(self, **kwargs) -> str:
"""Return the login page url."""
is_simple_auth_manager_all_admins = conf.getboolean("core", "simple_auth_manager_all_admins")
if is_simple_auth_manager_all_admins:
return "/auth/token"
return AUTH_MANAGER_FASTAPI_APP_PREFIX + "/token"

return "/auth/login"
return AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login"

def deserialize_user(self, token: dict[str, Any]) -> SimpleAuthManagerUser:
return SimpleAuthManagerUser(username=token["username"], role=token["role"])
Expand Down
9 changes: 9 additions & 0 deletions airflow/api_fastapi/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Annotated
Expand Down Expand Up @@ -72,3 +73,11 @@ class Mimetype(str, Enum):
TEXT = "text/plain"
JSON = "application/json"
ANY = "*/*"


@dataclass
class MenuItem:
"""Define a menu item."""

text: str
href: str
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from fastapi import FastAPI

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
from airflow.api_fastapi.auth.managers.models.resource_details import (
AccessView,
Expand Down Expand Up @@ -322,7 +323,7 @@ def _has_access_to_dag(request: IsAuthorizedRequest):
return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}

def get_url_login(self, **kwargs) -> str:
return urljoin(self.apiserver_endpoint, "auth/login")
return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")

@staticmethod
def get_cli_commands() -> list[CLICommand]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from starlette import status
from starlette.responses import RedirectResponse

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, get_auth_manager
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.configuration import conf
from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
Expand Down Expand Up @@ -94,7 +94,7 @@ def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
"sp": {
"entityId": "aws-auth-manager-saml-client",
"assertionConsumerService": {
"url": f"{base_url}/auth/login_callback",
"url": f"{base_url}{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login_callback",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@

import boto3
import pytest

from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS

if not AIRFLOW_V_3_0_PLUS:
pytest.skip("AWS auth manager is only compatible with Airflow >= 3.0.0", allow_module_level=True)

from fastapi.testclient import TestClient
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser

from airflow.api_fastapi.app import create_app
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, create_app
from system.amazon.aws.utils import set_env_id

from tests_common.test_utils.config import conf_vars
Expand Down Expand Up @@ -191,7 +197,9 @@ def delete_avp_policy_store(cls):
client.delete_policy_store(policyStoreId=policy_store_id)

def test_login_admin(self, client_admin_permissions):
response = client_admin_permissions.post("/auth/login_callback", follow_redirects=False)
response = client_admin_permissions.post(
AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login_callback", follow_redirects=False
)
assert response.status_code == 303
assert "location" in response.headers
assert "/?token=" in response.headers["location"]
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from fastapi.testclient import TestClient
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser

from airflow.api_fastapi.app import create_app
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, create_app

from tests_common.test_utils.config import conf_vars

Expand Down Expand Up @@ -75,7 +75,7 @@ def test_client():

class TestLoginRouter:
def test_login(self, test_client):
response = test_client.get("/auth/login", follow_redirects=False)
response = test_client.get(AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login", follow_redirects=False)
assert response.status_code == 307
assert "location" in response.headers
assert response.headers["location"].startswith(
Expand Down Expand Up @@ -114,7 +114,9 @@ def test_login_callback_successful(self):
}
mock_init_saml_auth.return_value = auth
client = TestClient(create_app())
response = client.post("/auth/login_callback", follow_redirects=False)
response = client.post(
AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login_callback", follow_redirects=False
)
assert response.status_code == 303
assert "location" in response.headers
assert response.headers["location"].startswith("http://localhost:8080/?token=")
Expand Down Expand Up @@ -145,5 +147,5 @@ def test_login_callback_unsuccessful(self):
auth.is_authenticated.return_value = False
mock_init_saml_auth.return_value = auth
client = TestClient(create_app())
response = client.post("/auth/login_callback")
response = client.post(AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login_callback")
assert response.status_code == 500
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
if not AIRFLOW_V_3_0_PLUS:
pytest.skip("AWS auth manager is only compatible with Airflow >= 3.0.0", allow_module_level=True)

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
Expand Down Expand Up @@ -569,7 +570,7 @@ def test_filter_permitted_dag_ids(self, method, user, auth_manager, test_user, e

def test_get_url_login(self, auth_manager):
result = auth_manager.get_url_login()
assert result == "http://localhost:8080/auth/login"
assert result == f"http://localhost:8080{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login"

def test_get_cli_commands_return_cli_commands(self, auth_manager):
assert len(auth_manager.get_cli_commands()) > 0
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from starlette.middleware.wsgi import WSGIMiddleware

from airflow import __version__ as airflow_version
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
from airflow.api_fastapi.auth.managers.models.resource_details import (
AccessView,
Expand All @@ -43,6 +44,7 @@
PoolDetails,
VariableDetails,
)
from airflow.api_fastapi.common.types import MenuItem
from airflow.cli.cli_config import (
DefaultHelpParser,
GroupCommand,
Expand Down Expand Up @@ -410,7 +412,7 @@ def security_manager(self) -> FabAirflowSecurityManagerOverride:

def get_url_login(self, **kwargs) -> str:
"""Return the login page url."""
return urljoin(self.apiserver_endpoint, "auth/login/")
return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login/")

def get_url_logout(self):
"""Return the logout page url."""
Expand All @@ -425,6 +427,49 @@ def logout(self) -> None:
def register_views(self) -> None:
self.security_manager.register_views()

def get_menu_items(self, *, user: User) -> list[MenuItem]:
# Contains the list of menu items. ``resource_type`` is the name of the resource in FAB
# permission model to check whether the user is allowed to see this menu item
items = [
{
"resource_type": "List Users",
"text": "Users",
"href": AUTH_MANAGER_FASTAPI_APP_PREFIX
+ url_for(f"{self.security_manager.user_view.__class__.__name__}.list", _external=False),
},
{
"resource_type": "List Roles",
"text": "Roles",
"href": AUTH_MANAGER_FASTAPI_APP_PREFIX
+ url_for("CustomRoleModelView.list", _external=False),
},
{
"resource_type": "Actions",
"text": "Actions",
"href": AUTH_MANAGER_FASTAPI_APP_PREFIX + url_for("ActionModelView.list", _external=False),
},
{
"resource_type": "Resources",
"text": "Resources",
"href": AUTH_MANAGER_FASTAPI_APP_PREFIX + url_for("ResourceModelView.list", _external=False),
},
{
"resource_type": "Permission Pairs",
"text": "Permissions",
"href": AUTH_MANAGER_FASTAPI_APP_PREFIX
+ url_for(
"PermissionPairModelView.list",
_external=False,
),
},
]

return [
MenuItem(text=item["text"], href=item["href"])
for item in items
if self._is_authorized(method="MENU", resource_type=item["resource_type"], user=user)
]

def _is_authorized(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest
from flask import Flask, g

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.providers.fab.www.extensions.init_appbuilder import init_appbuilder
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -98,6 +99,8 @@ def flask_app():

@pytest.fixture
def auth_manager_with_appbuilder(flask_app):
flask_app.config["AUTH_RATE_LIMITED"] = False
flask_app.config["SERVER_NAME"] = "localhost"
appbuilder = init_appbuilder(flask_app, enable_plugins=False)
auth_manager = FabAuthManager()
auth_manager.appbuilder = appbuilder
Expand Down Expand Up @@ -561,7 +564,7 @@ class TestSecurityManager:

def test_get_url_login(self, auth_manager):
result = auth_manager.get_url_login()
assert result == "http://localhost:8080/auth/login/"
assert result == f"http://localhost:8080{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login/"

@pytest.mark.db_test
def test_get_url_logout_when_auth_view_not_defined(self, auth_manager_with_appbuilder):
Expand All @@ -581,3 +584,11 @@ def test_get_url_logout(self, mock_url_for, auth_manager_with_appbuilder):
def test_logout(self, mock_logout_user, auth_manager_with_appbuilder):
auth_manager_with_appbuilder.logout()
mock_logout_user.assert_called_once()

@mock.patch.object(FabAuthManager, "_is_authorized", return_value=True)
def test_get_menu_items(self, _, auth_manager_with_appbuilder, flask_app):
with flask_app.app_context():
auth_manager_with_appbuilder.register_views()
result = auth_manager_with_appbuilder.get_menu_items(user=Mock())
assert len(result) == 5
assert all(item.href.startswith(AUTH_MANAGER_FASTAPI_APP_PREFIX) for item in result)
6 changes: 4 additions & 2 deletions scripts/in_container/run_update_fastapi_api_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import yaml
from fastapi.openapi.utils import get_openapi

from airflow.api_fastapi.app import create_app
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, create_app
from airflow.api_fastapi.auth.managers.simple import __file__ as SIMPLE_AUTH_MANAGER_PATH
from airflow.api_fastapi.auth.managers.simple.simple_auth_manager import SimpleAuthManager
from airflow.api_fastapi.core_api import __file__ as CORE_API_PATH
Expand Down Expand Up @@ -76,7 +76,9 @@ def generate_file(app: FastAPI, file_path: Path, prefix: str = ""):
simple_auth_manager_app = SimpleAuthManager().get_fastapi_app()
if simple_auth_manager_app:
generate_file(
app=simple_auth_manager_app, file_path=SIMPLE_AUTH_MANAGER_OPENAPI_SPEC_FILE, prefix="/auth"
app=simple_auth_manager_app,
file_path=SIMPLE_AUTH_MANAGER_OPENAPI_SPEC_FILE,
prefix=AUTH_MANAGER_FASTAPI_APP_PREFIX,
)

# Generate FAB auth manager openapi spec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import pytest

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.models.resource_details import AccessView
from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser

Expand Down Expand Up @@ -65,12 +66,12 @@ def test_init_with_all_admins(self, auth_manager):

def test_get_url_login(self, auth_manager):
result = auth_manager.get_url_login()
assert result == "/auth/login"
assert result == AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login"

def test_get_url_login_with_all_admins(self, auth_manager):
with conf_vars({("core", "simple_auth_manager_all_admins"): "true"}):
result = auth_manager.get_url_login()
assert result == "/auth/token"
assert result == AUTH_MANAGER_FASTAPI_APP_PREFIX + "/token"

def test_deserialize_user(self, auth_manager):
result = auth_manager.deserialize_user({"username": "test", "role": "admin"})
Expand Down
3 changes: 3 additions & 0 deletions tests/api_fastapi/auth/managers/test_base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def test_get_fastapi_app_return_none(self, auth_manager):
def test_logout_return_none(self, auth_manager):
assert auth_manager.logout() is None

def test_get_menu_items_return_empty_list(self, auth_manager):
assert auth_manager.get_menu_items(user=BaseAuthManagerUserTest(name="test")) == []

@patch("airflow.api_fastapi.auth.managers.base_auth_manager.JWTSigner")
@patch.object(EmptyAuthManager, "deserialize_user")
def test_get_user_from_token(self, mock_deserialize_user, mock_jwt_signer, auth_manager):
Expand Down
Loading