Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2676,3 +2676,15 @@ dag_processor:
type: integer
example: ~
default: "30"
fastapi:
description: Configuration for the Fastapi webserver.
options:
base_url:
description: |
The base url of the Fastapi endpoint. Airflow cannot guess what domain or CNAME you are using.
If the Airflow console (the front-end) and the Fastapi apis are on a different domain, this config
should contain the Fastapi apis endpoint.
version_added: ~
type: string
example: ~
default: "http://localhost:29091"
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from collections import defaultdict
from collections.abc import Container, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

from flask import session, url_for
from fastapi import FastAPI
from flask import session

from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.models.resource_details import (
Expand All @@ -34,6 +35,7 @@
VariableDetails,
)
from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
from airflow.configuration import conf
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.avp.facade import (
Expand All @@ -43,11 +45,7 @@
from airflow.providers.amazon.aws.auth_manager.cli.definition import (
AWS_AUTH_MANAGER_COMMANDS,
)
from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import (
AwsSecurityManagerOverride,
)
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
from airflow.providers.amazon.aws.auth_manager.views.auth import AwsAuthManagerAuthenticationViews
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
Expand All @@ -61,7 +59,6 @@
IsAuthorizedVariableRequest,
)
from airflow.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder


class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
Expand All @@ -72,8 +69,6 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
authentication and authorization in Airflow.
"""

appbuilder: AirflowAppBuilder | None = None

def __init__(self) -> None:
if not AIRFLOW_V_3_0_PLUS:
raise AirflowOptionalProviderFeatureException(
Expand All @@ -87,12 +82,27 @@ def __init__(self) -> None:
def avp_facade(self):
return AwsAuthManagerAmazonVerifiedPermissionsFacade()

@cached_property
def fastapi_endpoint(self) -> str:
return conf.get("fastapi", "base_url")

def get_user(self) -> AwsAuthManagerUser | None:
return session["aws_user"] if self.is_logged_in() else None

def is_logged_in(self) -> bool:
return "aws_user" in session

def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
return AwsAuthManagerUser(**token)

def serialize_user(self, user: AwsAuthManagerUser) -> dict[str, Any]:
return {
"user_id": user.get_id(),
"groups": user.get_groups(),
"username": user.username,
"email": user.email,
}

def is_authorized_configuration(
self,
*,
Expand Down Expand Up @@ -367,14 +377,10 @@ def _has_access_to_menu_item(request: IsAuthorizedRequest):
return accessible_items

def get_url_login(self, **kwargs) -> str:
return url_for("AwsAuthManagerAuthenticationViews.login")
return f"{self.fastapi_endpoint}/auth/login"

def get_url_logout(self) -> str:
return url_for("AwsAuthManagerAuthenticationViews.logout")

@cached_property
def security_manager(self) -> AwsSecurityManagerOverride:
return AwsSecurityManagerOverride(self.appbuilder)
raise NotImplementedError()

@staticmethod
def get_cli_commands() -> list[CLICommand]:
Expand All @@ -387,9 +393,20 @@ def get_cli_commands() -> list[CLICommand]:
),
]

def register_views(self) -> None:
if self.appbuilder:
self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews())
def get_fastapi_app(self) -> FastAPI | None:
from airflow.providers.amazon.aws.auth_manager.router.login import login_router

app = FastAPI(
title="AWS auth manager sub application",
description=(
"This is the AWS auth manager fastapi sub application. This API is only available if the "
"auth manager used in the Airflow environment is AWS auth manager. "
"This sub application provides login routes."
),
)
app.include_router(login_router)

return app

@staticmethod
def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
from typing import Any

import anyio
from fastapi import HTTPException, Request
from starlette import status
from starlette.responses import RedirectResponse

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.configuration import conf
from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser

try:
from onelogin.saml2.auth import OneLogin_Saml2_Auth
from onelogin.saml2.errors import OneLogin_Saml2_Error
from onelogin.saml2.idp_metadata_parser import OneLogin_Saml2_IdPMetadataParser
except ImportError:
raise ImportError(
"AWS auth manager requires the python3-saml library but it is not installed by default. "
"Please install the python3-saml library by running: "
"pip install apache-airflow-providers-amazon[python3-saml]"
)

log = logging.getLogger(__name__)
login_router = AirflowRouter(tags=["AWSAuthManagerLogin"])


@login_router.get("/login")
def login(request: Request):
"""Authenticate the user."""
saml_auth = _init_saml_auth(request)
callback_url = saml_auth.login()
return RedirectResponse(url=callback_url)


@login_router.post("/login_callback")
def login_callback(request: Request):
"""Authenticate the user."""
saml_auth = _init_saml_auth(request)
try:
saml_auth.process_response()
except OneLogin_Saml2_Error as e:
log.exception(e)
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to authenticate")
errors = saml_auth.get_errors()
is_authenticated = saml_auth.is_authenticated()
if not is_authenticated:
error_reason = saml_auth.get_last_error_reason()
log.error("Failed to authenticate")
log.error("Errors: %s", errors)
log.error("Error reason: %s", error_reason)
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to authenticate: {error_reason}")

attributes = saml_auth.get_attributes()
user = AwsAuthManagerUser(
user_id=attributes["id"][0],
groups=attributes["groups"],
username=saml_auth.get_nameid(),
email=attributes["email"][0] if "email" in attributes else None,
)
return RedirectResponse(url=f"/webapp?token={get_auth_manager().get_jwt_token(user)}", status_code=303)


def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
request_data = _prepare_request(request)
base_url = conf.get(section="fastapi", key="base_url")
settings = {
# We want to keep this flag on in case of errors.
# It provides an error reasons, if turned off, it does not
"debug": True,
"sp": {
"entityId": "aws-auth-manager-saml-client",
"assertionConsumerService": {
"url": f"{base_url}/auth/login_callback",
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
},
},
}
merged_settings = OneLogin_Saml2_IdPMetadataParser.merge_settings(_get_idp_data(), settings)
return OneLogin_Saml2_Auth(request_data, merged_settings)


def _prepare_request(request: Request) -> dict:
host = request.headers.get("host", request.client.host if request.client else "localhost")
data: dict[str, Any] = {
"https": "on" if request.url.scheme == "https" else "off",
"http_host": host,
"server_port": request.url.port,
"script_name": request.url.path,
"get_data": request.query_params,
"post_data": {},
}
form_data = anyio.from_thread.run(request.form)
if "SAMLResponse" in form_data:
data["post_data"]["SAMLResponse"] = form_data["SAMLResponse"]
if "RelayState" in form_data:
data["post_data"]["RelayState"] = form_data["RelayState"]
return data


def _get_idp_data() -> dict:
saml_metadata_url = conf.get_mandatory_value(CONF_SECTION_NAME, CONF_SAML_METADATA_URL_KEY)
return OneLogin_Saml2_IdPMetadataParser.parse_remote(saml_metadata_url)
Loading