Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/tests/unit/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def test_providers_modules_should_have_tests(self):
"providers/google/tests/unit/google/test_go_module_utils.py",
"providers/google/tests/unit/google/test_version_compat.py",
"providers/http/tests/unit/http/test_exceptions.py",
"providers/keycloak/tests/unit/keycloak/auth_manager/datamodels/test_token.py",
"providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_adls.py",
"providers/microsoft/azure/tests/unit/microsoft/azure/test_version_compat.py",
"providers/openlineage/tests/unit/openlineage/test_version_compat.py",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from pydantic import Field

from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel


class TokenResponse(BaseModel):
"""Token serializer for responses."""

access_token: str


class TokenBody(StrictBaseModel):
"""Token serializer for post bodies."""

username: str = Field()
password: str = Field()
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import requests
from fastapi import FastAPI
from keycloak import KeycloakOpenID

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
Expand All @@ -33,6 +34,7 @@
from airflow.providers.keycloak.auth_manager.cli.definition import KEYCLOAK_AUTH_MANAGER_COMMANDS
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
CONF_REALM_KEY,
CONF_SECTION_NAME,
CONF_SERVER_URL_KEY,
Expand Down Expand Up @@ -206,6 +208,7 @@ def filter_authorized_menu_items(

def get_fastapi_app(self) -> FastAPI | None:
from airflow.providers.keycloak.auth_manager.routes.login import login_router
from airflow.providers.keycloak.auth_manager.routes.token import token_router

app = FastAPI(
title="Keycloak auth manager sub application",
Expand All @@ -216,6 +219,7 @@ def get_fastapi_app(self) -> FastAPI | None:
),
)
app.include_router(login_router)
app.include_router(token_router)

return app

Expand All @@ -230,6 +234,20 @@ def get_cli_commands() -> list[CLICommand]:
),
]

@staticmethod
def get_keycloak_client() -> KeycloakOpenID:
client_id = conf.get(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY)
client_secret = conf.get(CONF_SECTION_NAME, CONF_CLIENT_SECRET_KEY)
realm = conf.get(CONF_SECTION_NAME, CONF_REALM_KEY)
server_url = conf.get(CONF_SECTION_NAME, CONF_SERVER_URL_KEY)

return KeycloakOpenID(
server_url=server_url,
client_id=client_id,
client_secret_key=client_secret,
realm_name=realm,
)

def _is_authorized(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,13 @@
import logging

from fastapi import Request # noqa: TC002
from keycloak import KeycloakOpenID
from starlette.responses import HTMLResponse, RedirectResponse

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.configuration import conf
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
CONF_REALM_KEY,
CONF_SECTION_NAME,
CONF_SERVER_URL_KEY,
)
from airflow.providers.keycloak.auth_manager.keycloak_auth_manager import KeycloakAuthManager
from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser

log = logging.getLogger(__name__)
Expand All @@ -43,7 +36,7 @@
@login_router.get("/login")
def login(request: Request) -> RedirectResponse:
"""Initiate the authentication."""
client = _get_keycloak_client()
client = KeycloakAuthManager.get_keycloak_client()
redirect_uri = request.url_for("login_callback")
auth_url = client.auth_url(redirect_uri=str(redirect_uri), scope="openid")
return RedirectResponse(auth_url)
Expand All @@ -56,7 +49,7 @@ def login_callback(request: Request):
if not code:
return HTMLResponse("Missing code", status_code=400)

client = _get_keycloak_client()
client = KeycloakAuthManager.get_keycloak_client()
redirect_uri = request.url_for("login_callback")

tokens = client.token(
Expand All @@ -77,17 +70,3 @@ def login_callback(request: Request):
secure = bool(conf.get("api", "ssl_cert", fallback=""))
response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
return response


def _get_keycloak_client() -> KeycloakOpenID:
client_id = conf.get(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY)
client_secret = conf.get(CONF_SECTION_NAME, CONF_CLIENT_SECRET_KEY)
realm = conf.get(CONF_SECTION_NAME, CONF_REALM_KEY)
server_url = conf.get(CONF_SECTION_NAME, CONF_SERVER_URL_KEY)

return KeycloakOpenID(
server_url=server_url,
client_id=client_id,
client_secret_key=client_secret,
realm_name=realm,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging

from fastapi import HTTPException
from keycloak import KeycloakAuthenticationError
from starlette import status

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.providers.keycloak.auth_manager.datamodels.token import TokenBody, TokenResponse
from airflow.providers.keycloak.auth_manager.keycloak_auth_manager import KeycloakAuthManager
from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser

log = logging.getLogger(__name__)
token_router = AirflowRouter(tags=["KeycloakAuthManagerToken"])


@token_router.post(
"/token",
status_code=status.HTTP_201_CREATED,
responses=create_openapi_http_exception_doc([status.HTTP_400_BAD_REQUEST, status.HTTP_401_UNAUTHORIZED]),
)
def create_token(body: TokenBody) -> TokenResponse:
client = KeycloakAuthManager.get_keycloak_client()

try:
tokens = client.token(body.username, body.password)
except KeycloakAuthenticationError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
)

userinfo = client.userinfo(tokens["access_token"])
user = KeycloakAuthManagerUser(
user_id=userinfo["sub"],
name=userinfo["preferred_username"],
access_token=tokens["access_token"],
refresh_token=tokens["refresh_token"],
)
token = get_auth_manager().generate_jwt(user)

return TokenResponse(access_token=token)
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest
from fastapi.testclient import TestClient

from airflow.api_fastapi.app import create_app
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
CONF_REALM_KEY,
CONF_SECTION_NAME,
)

from tests_common.test_utils.config import conf_vars


@pytest.fixture
def client():
with conf_vars(
{
(
"core",
"auth_manager",
): "airflow.providers.keycloak.auth_manager.keycloak_auth_manager.KeycloakAuthManager",
(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY): "test",
(CONF_SECTION_NAME, CONF_CLIENT_SECRET_KEY): "test",
(CONF_SECTION_NAME, CONF_REALM_KEY): "test",
(CONF_SECTION_NAME, "base_url"): "http://host.docker.internal:48080",
}
):
yield TestClient(create_app())
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,11 @@

from unittest.mock import ANY, Mock, patch

import pytest
from fastapi.testclient import TestClient

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, create_app
from airflow.providers.keycloak.auth_manager.constants import (
CONF_CLIENT_ID_KEY,
CONF_CLIENT_SECRET_KEY,
CONF_REALM_KEY,
CONF_SECTION_NAME,
)

from tests_common.test_utils.config import conf_vars


@pytest.fixture
def client():
with conf_vars(
{
(
"core",
"auth_manager",
): "airflow.providers.keycloak.auth_manager.keycloak_auth_manager.KeycloakAuthManager",
(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY): "test",
(CONF_SECTION_NAME, CONF_CLIENT_SECRET_KEY): "test",
(CONF_SECTION_NAME, CONF_REALM_KEY): "test",
(CONF_SECTION_NAME, "base_url"): "http://host.docker.internal:48080",
}
):
yield TestClient(create_app())
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX


class TestLoginRouter:
@patch("airflow.providers.keycloak.auth_manager.routes.login._get_keycloak_client")
@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
def test_login(self, mock_get_keycloak_client, client):
redirect_url = "redirect_url"
mock_keycloak_client = Mock()
Expand All @@ -62,7 +34,7 @@ def test_login(self, mock_get_keycloak_client, client):
assert response.headers["location"] == redirect_url

@patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager")
@patch("airflow.providers.keycloak.auth_manager.routes.login._get_keycloak_client")
@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
def test_login_callback(self, mock_get_keycloak_client, mock_get_auth_manager, client):
code = "code"
token = "token"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest.mock import Mock, patch

from keycloak import KeycloakAuthenticationError

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX


class TestTokenRouter:
@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
def test_create_token(self, mock_get_keycloak_client, client):
mock_keycloak_client = Mock()
mock_keycloak_client.token.return_value = {
"access_token": "access_token",
"refresh_token": "refresh_token",
}
mock_keycloak_client.userinfo.return_value = {"sub": "sub", "preferred_username": "username"}
mock_get_keycloak_client.return_value = mock_keycloak_client
response = client.post(
AUTH_MANAGER_FASTAPI_APP_PREFIX + "/token",
json={"username": "username", "password": "password"},
)

assert response.status_code == 201
mock_keycloak_client.token.assert_called_once_with("username", "password")
mock_keycloak_client.userinfo.assert_called_once_with("access_token")

@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
def test_create_token_with_invalid_creds(self, mock_get_keycloak_client, client):
mock_keycloak_client = Mock()
mock_keycloak_client.token.side_effect = KeycloakAuthenticationError()
mock_get_keycloak_client.return_value = mock_keycloak_client
response = client.post(
AUTH_MANAGER_FASTAPI_APP_PREFIX + "/token",
json={"username": "username", "password": "password"},
)

assert response.status_code == 401
mock_keycloak_client.token.assert_called_once_with("username", "password")