@@ -342,3 +342,9 @@ class TIRuntimeCheckPayload(StrictBaseModel):

inlets: list[AssetProfile] | None = None
outlets: list[AssetProfile] | None = None


class TaskStatesResponse(BaseModel):
"""Response for task states with run_id, task and state."""

task_states: dict[str, Any]
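
As an editorial aside (not part of the diff): task_states is keyed by run_id and then by task_id, with the task instance state as the value, so a populated payload might look like the sketch below (all IDs and states are made up).

# Illustrative shape of TaskStatesResponse.task_states; values are invented.
example_task_states = {
    "manual__2025-01-01T00:00:00+00:00": {
        "extract": "success",
        "transform": "running",
    },
}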
@@ -19,7 +19,8 @@

import json
import logging
from typing import Annotated
from collections import defaultdict
from typing import Annotated, Any
from uuid import UUID

from cadwyn import VersionedAPIRouter
@@ -34,6 +35,7 @@
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
PrevSuccessfulDagRunResponse,
TaskStatesResponse,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
@@ -607,36 +609,7 @@ def get_count(
query = query.where(TI.run_id.in_(run_ids))

if task_group_id:
# Get all tasks in the task group
dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session)
if not dag:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"DAG {dag_id} not found",
},
)

task_group = dag.task_group_dict.get(task_group_id)
if not task_group:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"Task group {task_group_id} not found in DAG {dag_id}",
},
)

# First get all task instances to get the task_id, map_index pairs
group_tasks = session.scalars(
select(TI).where(
TI.dag_id == dag_id,
TI.task_id.in_(task.task_id for task in task_group.iter_tasks()),
*([TI.logical_date.in_(logical_dates)] if logical_dates else []),
*([TI.run_id.in_(run_ids)] if run_ids else []),
)
).all()
group_tasks = _get_group_tasks(dag_id, task_group_id, session, logical_dates, run_ids)

# Get unique (task_id, map_index) pairs
task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks]
@@ -659,9 +632,45 @@ def get_count(
query = query.where(TI.state.in_(states))

count = session.scalar(query)

return count or 0


@router.get("/states", status_code=status.HTTP_200_OK)
def get_task_states(
dag_id: str,
session: SessionDep,
task_ids: Annotated[list[str] | None, Query()] = None,
task_group_id: Annotated[str | None, Query()] = None,
logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
run_ids: Annotated[list[str] | None, Query()] = None,
) -> TaskStatesResponse:
"""Get the task states for the given criteria."""
run_id_task_state_map: dict[str, dict[str, Any]] = defaultdict(dict)

query = select(TI).where(TI.dag_id == dag_id)

if task_ids:
query = query.where(TI.task_id.in_(task_ids))

if logical_dates:
query = query.where(TI.logical_date.in_(logical_dates))

if run_ids:
query = query.where(TI.run_id.in_(run_ids))

results = session.scalars(query).all()

    for ti in results:
        run_id_task_state_map[ti.run_id][ti.task_id] = ti.state

if task_group_id:
group_tasks = _get_group_tasks(dag_id, task_group_id, session, logical_dates, run_ids)

        for ti in group_tasks:
            run_id_task_state_map[ti.run_id][ti.task_id] = ti.state

return TaskStatesResponse(task_states=run_id_task_state_map)
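
To make the grouping concrete, here is a small self-contained sketch (editorial, not part of the diff) of the run_id -> task_id -> state aggregation this endpoint returns, using stand-in tuples in place of TI rows.

# Stand-in rows as (run_id, task_id, state); the endpoint iterates TI ORM objects instead.
from collections import defaultdict

rows = [
    ("run_a", "extract", "success"),
    ("run_a", "transform", "running"),
    ("run_b", "extract", "queued"),
]
task_states: dict[str, dict[str, str]] = defaultdict(dict)
for run_id, task_id, state in rows:
    task_states[run_id][task_id] = state
# task_states == {"run_a": {"extract": "success", "transform": "running"},
#                 "run_b": {"extract": "queued"}}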


@ti_id_router.only_exists_in_older_versions
@ti_id_router.post(
"/{task_instance_id}/runtime-checks",
@@ -702,5 +711,40 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
return max_tries != 0 and try_number <= max_tries


def _get_group_tasks(dag_id: str, task_group_id: str, session: SessionDep, logical_dates: list[UtcDateTime] | None = None, run_ids: list[str] | None = None):
# Get all tasks in the task group
dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session)
if not dag:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"DAG {dag_id} not found",
},
)

task_group = dag.task_group_dict.get(task_group_id)
if not task_group:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"Task group {task_group_id} not found in DAG {dag_id}",
},
)

    # Fetch the matching task instances for the task group, optionally filtered by logical_dates/run_ids
group_tasks = session.scalars(
select(TI).where(
TI.dag_id == dag_id,
TI.task_id.in_(task.task_id for task in task_group.iter_tasks()),
*([TI.logical_date.in_(logical_dates)] if logical_dates else []),
*([TI.run_id.in_(run_ids)] if run_ids else []),
)
).all()

return group_tasks


# This line should be at the end of the file to ensure all routes are registered
router.include_router(ti_id_router)
airflow-core/src/airflow/jobs/triggerer_job_runner.py (24 changes: 23 additions & 1 deletion)
@@ -50,9 +50,11 @@
GetConnection,
GetDagRunState,
GetDRCount,
GetTaskStates,
GetTICount,
GetVariable,
GetXCom,
TaskStatesResult,
TICount,
VariableResult,
XComResult,
@@ -225,6 +227,7 @@ class TriggerStateSync(BaseModel):
DagRunStateResult,
DRCount,
TICount,
TaskStatesResult,
ErrorResponse,
],
Field(discriminator="type"),
@@ -242,6 +245,7 @@ class TriggerStateSync(BaseModel):
GetVariable,
GetXCom,
GetTICount,
GetTaskStates,
GetDagRunState,
GetDRCount,
],
@@ -360,7 +364,12 @@ def client(self) -> Client:
return client

def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None: # type: ignore[override]
from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse, XComResponse
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
TaskStatesResponse,
VariableResponse,
XComResponse,
)

resp: BaseModel | None = None
dump_opts = {}
@@ -435,6 +444,19 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger) -> None:
run_ids=msg.run_ids,
states=msg.states,
)

elif isinstance(msg, GetTaskStates):
run_id_task_state_map = self.client.task_instances.get_task_states(
dag_id=msg.dag_id,
task_ids=msg.task_ids,
task_group_id=msg.task_group_id,
logical_dates=msg.logical_dates,
run_ids=msg.run_ids,
)
if isinstance(run_id_task_state_map, TaskStatesResponse):
resp = TaskStatesResult.from_api_response(run_id_task_state_map)
else:
resp = run_id_task_state_map
else:
raise ValueError(f"Unknown message type {type(msg)}")
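
For context, a hedged sketch of how the new endpoint is reached from here: the supervisor forwards a GetTaskStates message through the Task SDK client call shown above. How `client` is obtained and the exact fields of the generated TaskStatesResponse are assumptions, not shown in this diff.

# Sketch only: mirrors the client call in the GetTaskStates branch above.
result = client.task_instances.get_task_states(
    dag_id="example_dag",
    task_ids=None,
    task_group_id="my_group",
    logical_dates=None,
    run_ids=["manual__2025-01-01T00:00:00+00:00"],
)
# On success this is expected to be a TaskStatesResponse whose task_states maps
# run_id -> {task_id: state}; otherwise the branch above passes the error
# response through unchanged.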
