Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

if TYPE_CHECKING:
import httpx
from fastapi.routing import APIRoute

import structlog

Expand Down Expand Up @@ -113,6 +114,10 @@ def customize_openapi(self, openapi_schema: dict[str, Any]) -> dict[str, Any]:
This is particularly useful for client SDKs that require models for types
not directly exposed in any endpoint's request or response schema.

We also replace ``anyOf`` with ``oneOf`` in the API spec as this produces better results for the code
generators. This is because anyOf can technically be more than of the given schemas, but 99.9% of the
time (perhaps 100% in this API) the types are mutually exclusive, so oneOf is more correct

References:
- https://fastapi.tiangolo.com/how-to/extending-openapi/#modify-the-openapi-schema
"""
Expand All @@ -124,11 +129,23 @@ def customize_openapi(self, openapi_schema: dict[str, Any]) -> dict[str, Any]:
# The `JsonValue` component is missing any info. causes issues when generating models
openapi_schema["components"]["schemas"]["JsonValue"] = {
"title": "Any valid JSON value",
"anyOf": [
"oneOf": [
{"type": t} for t in ("string", "number", "integer", "object", "array", "boolean", "null")
],
}

def replace_any_of_with_one_of(spec):
if isinstance(spec, dict):
return {
("oneOf" if key == "anyOf" else key): replace_any_of_with_one_of(value)
for key, value in spec.items()
}
if isinstance(spec, list):
return [replace_any_of_with_one_of(item) for item in spec]
return spec

openapi_schema = replace_any_of_with_one_of(openapi_schema)

for comp in openapi_schema["components"]["schemas"].values():
for prop in comp.get("properties", {}).values():
# {"type": "string", "const": "deferred"}
Expand All @@ -147,11 +164,16 @@ def create_task_execution_api_app() -> FastAPI:
from airflow.api_fastapi.execution_api.routes import execution_api_router
from airflow.api_fastapi.execution_api.versions import bundle

def custom_generate_unique_id(route: APIRoute):
# This is called only if the route doesn't provide an explicit operation ID
return route.name

# See https://docs.cadwyn.dev/concepts/version_changes/ for info about API versions
app = CadwynWithOpenAPICustomization(
title="Airflow Task Execution API",
description="The private Airflow Task Execution API.",
lifespan=lifespan,
generate_unique_id_function=custom_generate_unique_id,
api_version_parameter_name="Airflow-API-Version",
api_version_default_value=bundle.versions[0].value,
versions=bundle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@

from typing import Annotated

from fastapi import HTTPException, Query, status
from fastapi import APIRouter, HTTPException, Query, status
from sqlalchemy import and_, select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
from airflow.api_fastapi.execution_api.datamodels.asset_event import (
AssetEventResponse,
Expand All @@ -32,7 +31,7 @@
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel

# TODO: Add dependency on JWT token
router = AirflowRouter(
router = APIRouter(
responses={
status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@

from typing import Annotated

from fastapi import HTTPException, Query, status
from fastapi import APIRouter, HTTPException, Query, status
from sqlalchemy import select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
from airflow.models.asset import AssetModel

# TODO: Add dependency on JWT token
router = AirflowRouter(
router = APIRouter(
responses={
status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
import logging
from typing import Annotated

from fastapi import HTTPException, Query, status
from fastapi import APIRouter, HTTPException, Query, status
from sqlalchemy import func, select

from airflow.api.common.trigger_dag import trigger_dag
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload
from airflow.exceptions import DagRunAlreadyExists
Expand All @@ -34,7 +33,7 @@
from airflow.models.dagrun import DagRun
from airflow.utils.types import DagRunTriggeredByType

router = AirflowRouter()
router = APIRouter()


log = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

from __future__ import annotations

from fastapi import APIRouter
from fastapi.responses import JSONResponse

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.deps import DepContainer

router = AirflowRouter()
router = APIRouter()


@router.get("")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def get_previous_successful_dagrun(


@router.get("/count", status_code=status.HTTP_200_OK)
def get_count(
def get_task_instance_count(
dag_id: str,
session: SessionDep,
task_ids: Annotated[list[str] | None, Query()] = None,
Expand Down Expand Up @@ -640,15 +640,15 @@ def get_count(


@router.get("/states", status_code=status.HTTP_200_OK)
def get_task_states(
def get_task_instance_states(
dag_id: str,
session: SessionDep,
task_ids: Annotated[list[str] | None, Query()] = None,
task_group_id: Annotated[str | None, Query()] = None,
logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
run_ids: Annotated[list[str] | None, Query()] = None,
) -> TaskStatesResponse:
"""Get the task states for the given criteria."""
"""Get the states for Task Instances with the given criteria."""
run_id_task_state_map: dict[str, dict[str, Any]] = defaultdict(dict)

query = select(TI).where(TI.dag_id == dag_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@

from uuid import UUID

from fastapi import status
from fastapi import APIRouter, status
from sqlalchemy import select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.models.taskreschedule import TaskReschedule

router = AirflowRouter(
router = APIRouter(
responses={
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@
import sys
from typing import Annotated, Any

from fastapi import Body, Depends, HTTPException, Path, Query, Request, Response, status
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request, Response, status
from pydantic import BaseModel, JsonValue
from sqlalchemy import delete
from sqlalchemy.sql.selectable import Select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.models.taskmap import TaskMap
Expand Down Expand Up @@ -57,7 +56,7 @@ async def has_xcom_access(
return True


router = AirflowRouter(
router = APIRouter(
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the XCom"},
Expand Down Expand Up @@ -95,7 +94,7 @@ async def xcom_query(
"description": "Metadata about the number of matching XCom values",
"headers": {
"Content-Range": {
"pattern": r"^map_indexes \d+$",
"schema": {"pattern": r"^map_indexes \d+$"},
"description": "The number of (mapped) XCom values found for this task.",
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_custom_openapi_includes_extra_schemas(client):
assert "TaskInstance" in openapi_schema["components"]["schemas"]
schema = openapi_schema["components"]["schemas"]["TaskInstance"]

assert schema == TaskInstance.model_json_schema()
assert schema["properties"].keys() == TaskInstance.model_json_schema()["properties"].keys()


def test_access_api_contract(client):
Expand Down
4 changes: 4 additions & 0 deletions task-sdk/dev/generate_task_sdk_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import json
import os
import sys
from pathlib import Path
Expand All @@ -29,6 +30,7 @@
PythonVersion,
generate as generate_models,
)
from openapi_spec_validator import validate_spec

os.environ["_AIRFLOW__AS_LIBRARY"] = "1"

Expand Down Expand Up @@ -84,6 +86,8 @@ def generate_file():
client.get(f"http://localhost/openapi.json?version={latest_version}").raise_for_status().text
)

validate_spec(json.loads(openapi_schema))

os.chdir(AIRFLOW_TASK_SDK_ROOT_PATH)

args = load_config()
Expand Down
1 change: 1 addition & 0 deletions task-sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ exclude_also = [
[dependency-groups]
codegen = [
"datamodel-code-generator[http]==0.28.2",
"openapi-spec-validator>=0.7.1",
"svcs>=25.1.0",
"ruff==0.11.2",
"rich>=12.4.4",
Expand Down
Loading