Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ paths:
content:
application/json:
schema: {}
security:
- OAuth2PasswordBearer: []
- HTTPBearer: []
/auth/logout_callback:
get:
tags:
Expand Down Expand Up @@ -106,8 +103,8 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
'401':
description: Unauthorized
'403':
description: Forbidden
content:
application/json:
schema:
Expand Down Expand Up @@ -143,8 +140,8 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
'401':
description: Unauthorized
'403':
description: Forbidden
content:
application/json:
schema:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
log = logging.getLogger(__name__)
login_router = AirflowRouter(tags=["KeycloakAuthManagerLogin"])

COOKIE_NAME_ID_TOKEN = "_id_token"


@login_router.get("/login")
def login(request: Request) -> RedirectResponse:
Expand Down Expand Up @@ -77,26 +79,30 @@
response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure, httponly=True)
else:
response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)

# Save id token as separate cookie.
# Cookies have a size limit (usually 4k), saving all the tokens in a same cookie goes beyond this limit
response.set_cookie(COOKIE_NAME_ID_TOKEN, tokens["id_token"], secure=secure, httponly=True)

return response


@login_router.get("/logout")
def logout(request: Request, user: Annotated[KeycloakAuthManagerUser, Depends(get_user)]):
def logout(request: Request):
"""Log out the user from Keycloak."""
auth_manager = cast("KeycloakAuthManager", get_auth_manager())
keycloak_config = auth_manager.get_keycloak_client().well_known()
end_session_endpoint = keycloak_config["end_session_endpoint"]

# Use the refresh flow to get the id token, it avoids us to save the id token
token_id = auth_manager.refresh_tokens(user=user).get("id_token")
id_token = request.cookies.get(COOKIE_NAME_ID_TOKEN)
post_logout_redirect_uri = request.url_for("logout_callback")

if token_id:
logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}&id_token_hint={token_id}"
if id_token:
logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}&id_token_hint={id_token}"
else:
logout_url = f"{end_session_endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}"
logout_url = str(post_logout_redirect_uri)

return RedirectResponse(logout_url)

Check warning

Code scanning / CodeQL

URL redirection from remote source Medium

Untrusted URL redirection depends on a
user-provided value
.


@login_router.get("/logout_callback")
Expand All @@ -114,6 +120,11 @@
secure=secure,
httponly=True,
)
response.delete_cookie(
key=COOKIE_NAME_ID_TOKEN,
secure=secure,
httponly=True,
)
return response


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
@token_router.post(
"/token",
status_code=status.HTTP_201_CREATED,
responses=create_openapi_http_exception_doc([status.HTTP_400_BAD_REQUEST, status.HTTP_401_UNAUTHORIZED]),
responses=create_openapi_http_exception_doc([status.HTTP_400_BAD_REQUEST, status.HTTP_403_FORBIDDEN]),
)
def create_token(body: TokenBody) -> TokenResponse:
token = body.root.create_token(
Expand All @@ -49,7 +49,7 @@ def create_token(body: TokenBody) -> TokenResponse:
@token_router.post(
"/token/cli",
status_code=status.HTTP_201_CREATED,
responses=create_openapi_http_exception_doc([status.HTTP_400_BAD_REQUEST, status.HTTP_401_UNAUTHORIZED]),
responses=create_openapi_http_exception_doc([status.HTTP_400_BAD_REQUEST, status.HTTP_403_FORBIDDEN]),
)
def create_token_cli(body: TokenPasswordBody) -> TokenResponse:
token = body.create_token(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def create_token_for(
tokens = client.token(username, password)
except KeycloakAuthenticationError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)

Expand Down Expand Up @@ -77,7 +77,7 @@ def create_client_credentials_token(
tokens = client.token(grant_type="client_credentials")
except KeycloakAuthenticationError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
status_code=status.HTTP_403_FORBIDDEN,
detail="Client credentials authentication failed",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from unittest.mock import ANY, Mock, patch

from keycloak import KeycloakPostError
import pytest

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser
Expand All @@ -45,6 +45,7 @@ def test_login_callback(self, mock_get_keycloak_client, mock_get_auth_manager, c
mock_keycloak_client.token.return_value = {
"access_token": "access_token",
"refresh_token": "refresh_token",
"id_token": "id_token",
}
mock_keycloak_client.userinfo.return_value = {
"sub": "sub",
Expand Down Expand Up @@ -73,43 +74,35 @@ def test_login_callback(self, mock_get_keycloak_client, mock_get_auth_manager, c
assert "location" in response.headers
assert "_token" in response.cookies
assert response.cookies["_token"] == token
assert response.cookies["_id_token"] == "id_token"

def test_login_callback_without_code(self, client):
response = client.get(AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login_callback")
assert response.status_code == 400

@pytest.mark.parametrize(
("id_token", "logout_callback_url"),
[
(None, "http://testserver/auth/logout_callback"),
(
"id_token",
"logout_url?post_logout_redirect_uri=http://testserver/auth/logout_callback&id_token_hint=id_token",
),
],
)
@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
def test_logout(self, mock_get_keycloak_client, client):
def test_logout(self, mock_get_keycloak_client, id_token, logout_callback_url, client):
mock_keycloak_client = Mock()
mock_keycloak_client.well_known.return_value = {"end_session_endpoint": "logout_url"}
mock_keycloak_client.refresh_token.return_value = {"id_token": "id_token"}
mock_get_keycloak_client.return_value = mock_keycloak_client
response = client.get(AUTH_MANAGER_FASTAPI_APP_PREFIX + "/logout", follow_redirects=False)
assert response.status_code == 307
assert "location" in response.headers
assert (
response.headers["location"]
== "logout_url?post_logout_redirect_uri=http://testserver/auth/logout_callback&id_token_hint=id_token"
)
mock_keycloak_client.refresh_token.assert_called_once_with("refresh_token")

@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
def test_logout_when_keycloak_client_raises_keycloak_post_error(self, mock_get_keycloak_client, client):
mock_keycloak_client = Mock()
mock_keycloak_client.well_known.return_value = {"end_session_endpoint": "logout_url"}
mock_keycloak_client.refresh_token.side_effect = KeycloakPostError(
response_code=400,
response_body=b'{"error":"invalid_grant","error_description":"Token is not active"}',
response = client.get(
AUTH_MANAGER_FASTAPI_APP_PREFIX + "/logout",
cookies={"_id_token": id_token},
follow_redirects=False,
)
mock_get_keycloak_client.return_value = mock_keycloak_client
response = client.get(AUTH_MANAGER_FASTAPI_APP_PREFIX + "/logout", follow_redirects=False)
assert response.status_code == 307
assert "location" in response.headers
assert (
response.headers["location"]
== "logout_url?post_logout_redirect_uri=http://testserver/auth/logout_callback"
)
mock_keycloak_client.refresh_token.assert_called_once_with("refresh_token")
assert response.headers["location"] == logout_callback_url

@patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager")
def test_refresh_token(self, mock_get_auth_manager, client):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,5 @@ def test_create_token_client_credentials_with_invalid_credentials(self, mock_get
with pytest.raises(fastapi.exceptions.HTTPException) as exc_info:
create_client_credentials_token(client_id=test_client_id, client_secret=test_client_secret)

assert exc_info.value.status_code == 401
assert exc_info.value.status_code == 403
assert "Client credentials authentication failed" in exc_info.value.detail