airflow/api_fastapi/core_api/datamodels/dag_tags.py (12 additions & 1 deletion)

```diff
@@ -17,7 +17,18 @@
 
 from __future__ import annotations
 
-from pydantic import BaseModel
+from pydantic import ConfigDict
+
+from airflow.api_fastapi.core_api.base import BaseModel
+
+
+class DagTagResponse(BaseModel):
+    """DAG Tag serializer for responses."""
+
+    model_config = ConfigDict(populate_by_name=True, from_attributes=True)
+
+    name: str
+    dag_id: str
 
 
 class DAGTagCollectionResponse(BaseModel):
```
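A note on the new model: `from_attributes=True` is what lets a response be built straight from the `DagTag` ORM row. A minimal sketch of that behavior, using plain `pydantic.BaseModel` and a hypothetical stand-in for the ORM row (the real code goes through Airflow's `BaseModel` wrapper):

```python
from pydantic import BaseModel, ConfigDict


class FakeDagTagRow:
    """Hypothetical stand-in for the DagTag ORM row."""

    def __init__(self, name: str, dag_id: str) -> None:
        self.name = name
        self.dag_id = dag_id


class DagTagResponse(BaseModel):
    """DAG Tag serializer for responses."""

    model_config = ConfigDict(populate_by_name=True, from_attributes=True)

    name: str
    dag_id: str


# from_attributes=True lets model_validate read attributes off a plain
# object, so an ORM row converts without an intermediate dict.
tag = DagTagResponse.model_validate(FakeDagTagRow("team-a", "example_dag"))
print(tag.model_dump())  # {'name': 'team-a', 'dag_id': 'example_dag'}
```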
airflow/api_fastapi/core_api/datamodels/dags.py (2 additions & 2 deletions)

```diff
@@ -32,8 +32,8 @@
 )
 
 from airflow.api_fastapi.core_api.base import BaseModel
+from airflow.api_fastapi.core_api.datamodels.dag_tags import DagTagResponse
 from airflow.configuration import conf
-from airflow.serialization.pydantic.dag import DagTagPydantic
 
 
 class DAGResponse(BaseModel):
@@ -50,7 +50,7 @@ class DAGResponse(BaseModel):
     description: str | None
     timetable_summary: str | None
     timetable_description: str | None
-    tags: list[DagTagPydantic]
+    tags: list[DagTagResponse]
     max_active_tasks: int
     max_active_runs: int | None
     max_consecutive_failed_dag_runs: int
```
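Since `DagTagResponse` carries its own `from_attributes` config, the nested `tags` field should validate directly from an ORM relationship as well. A rough sketch under the same stand-in assumptions:

```python
from pydantic import BaseModel, ConfigDict


class DagTagResponse(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    name: str
    dag_id: str


class DAGResponse(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    dag_id: str
    tags: list[DagTagResponse]


class FakeTagRow:
    """Hypothetical tag row found on a .tags relationship."""

    def __init__(self, name: str, dag_id: str) -> None:
        self.name, self.dag_id = name, dag_id


class FakeDagRow:
    """Hypothetical DAG row exposing dag_id and tags attributes."""

    def __init__(self) -> None:
        self.dag_id = "example_dag"
        self.tags = [FakeTagRow("team-a", self.dag_id)]


# Nested models validate recursively, so the list of ORM tag rows
# becomes a list of DagTagResponse objects in one call.
print(DAGResponse.model_validate(FakeDagRow()).model_dump())
# {'dag_id': 'example_dag', 'tags': [{'name': 'team-a', 'dag_id': 'example_dag'}]}
```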
airflow/api_fastapi/core_api/openapi/v1-generated.yaml (6 additions & 7 deletions)

```diff
@@ -6714,7 +6714,7 @@ components:
           title: Timetable Description
         tags:
           items:
-            $ref: '#/components/schemas/DagTagPydantic'
+            $ref: '#/components/schemas/DagTagResponse'
           type: array
           title: Tags
         max_active_tasks:
@@ -6936,7 +6936,7 @@ components:
           title: Timetable Description
         tags:
           items:
-            $ref: '#/components/schemas/DagTagPydantic'
+            $ref: '#/components/schemas/DagTagResponse'
           type: array
           title: Tags
         max_active_tasks:
@@ -7412,7 +7412,7 @@ components:
           title: Timetable Description
         tags:
           items:
-            $ref: '#/components/schemas/DagTagPydantic'
+            $ref: '#/components/schemas/DagTagResponse'
           type: array
           title: Tags
         max_active_tasks:
@@ -7665,7 +7665,7 @@
       - count
       title: DagStatsStateResponse
       description: DagStatsState serializer for responses.
-    DagTagPydantic:
+    DagTagResponse:
       properties:
         name:
           type: string
@@ -7677,9 +7677,8 @@
       required:
       - name
      - dag_id
-      title: DagTagPydantic
-      description: Serializable representation of the DagTag ORM SqlAlchemyModel used
-        by internal API.
+      title: DagTagResponse
+      description: DAG Tag serializer for responses.
     DagWarningType:
       type: string
       enum:
```
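These YAML edits follow mechanically from the rename: the spec is generated from the FastAPI app, and component names come from the Pydantic class names. A small standalone sketch of that mechanism (a hypothetical app, not Airflow's):

```python
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class DagTagResponse(BaseModel):
    """DAG Tag serializer for responses."""

    name: str
    dag_id: str


@app.get("/tags", response_model=list[DagTagResponse])
def get_tags() -> list[DagTagResponse]:
    return []


# The component name follows the class name, and the schema description
# is taken from the class docstring.
schemas = app.openapi()["components"]["schemas"]
print("DagTagResponse" in schemas)               # True
print(schemas["DagTagResponse"]["description"])  # DAG Tag serializer for responses.
```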
```diff
@@ -33,7 +33,7 @@ class TIEnterRunningPayload(BaseModel):
 
     state: Annotated[
         Literal[TIState.RUNNING],
-        # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+        # Specify a default in the schema, but not in code.
         WithJsonSchema({"type": "string", "enum": [TIState.RUNNING], "default": TIState.RUNNING}),
     ]
     hostname: str
```
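Whichever wording of the comment stands, the pattern is the same: with no Python-level default the field stays required at validation time, while `WithJsonSchema` controls only what the published schema advertises. A minimal sketch with a plain `"running"` literal standing in for `TIState.RUNNING`:

```python
from typing import Annotated, Literal

from pydantic import BaseModel, WithJsonSchema


class Payload(BaseModel):
    # No default on the Python side, so validation still requires `state`,
    # but the emitted JSON schema advertises a default to clients.
    state: Annotated[
        Literal["running"],
        WithJsonSchema({"type": "string", "enum": ["running"], "default": "running"}),
    ]


schema = Payload.model_json_schema()
print(schema["properties"]["state"])
# {'type': 'string', 'enum': ['running'], 'default': 'running'}
print(schema["required"])  # ['state']
```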
airflow/cli/commands/remote_commands/task_command.py (2 additions & 18 deletions)

```diff
@@ -47,7 +47,6 @@
 from airflow.models.dagrun import DagRun
 from airflow.models.param import ParamsDict
 from airflow.models.taskinstance import TaskReturnCode
-from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
@@ -74,7 +73,6 @@
     from sqlalchemy.orm.session import Session
 
     from airflow.models.operator import Operator
-    from airflow.serialization.pydantic.dag_run import DagRunPydantic
 
 log = logging.getLogger(__name__)
 
@@ -96,7 +94,7 @@ def _fetch_dag_run_from_run_id_or_logical_date_string(
     dag_id: str,
     value: str,
     session: Session,
-) -> tuple[DagRun | DagRunPydantic, pendulum.DateTime | None]:
+) -> tuple[DagRun, pendulum.DateTime | None]:
     """
     Try to find a DAG run with a given string value.
 
@@ -132,7 +130,7 @@ def _get_dag_run(
     create_if_necessary: CreateIfNecessary,
     logical_date_or_run_id: str | None = None,
     session: Session | None = None,
-) -> tuple[DagRun | DagRunPydantic, bool]:
+) -> tuple[DagRun, bool]:
     """
     Try to retrieve a DAG run from a string representing either a run ID or logical date.
 
@@ -259,8 +257,6 @@ def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None | TaskReturnCode:
     - as raw task
     - by executor
     """
-    if TYPE_CHECKING:
-        assert not isinstance(ti, TaskInstancePydantic)  # Wait for AIP-44 implementation to complete
     if args.local:
         return _run_task_by_local_task_job(args, ti)
     if args.raw:
@@ -497,9 +493,6 @@ def task_failed_deps(args) -> None:
     dag = get_dag(args.subdir, args.dag_id)
     task = dag.get_task(task_id=args.task_id)
     ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
-    # tasks_failed-deps is executed with access to the database.
-    if isinstance(ti, TaskInstancePydantic):
-        raise ValueError("not a TaskInstance")
     dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
     failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
     # TODO, Do we want to print or log this
@@ -524,9 +517,6 @@ def task_state(args) -> None:
     dag = get_dag(args.subdir, args.dag_id)
     task = dag.get_task(task_id=args.task_id)
     ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
-    # task_state is executed with access to the database.
-    if isinstance(ti, TaskInstancePydantic):
-        raise ValueError("not a TaskInstance")
     print(ti.current_state())
 
 
@@ -654,9 +644,6 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
     ti, dr_created = _get_ti(
         task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="db"
     )
-    # task_test is executed with access to the database.
-    if isinstance(ti, TaskInstancePydantic):
-        raise ValueError("not a TaskInstance")
     try:
         with redirect_stdout(RedactedIO()):
             if args.dry_run:
@@ -705,9 +692,6 @@ def task_render(args, dag: DAG | None = None) -> None:
     ti, _ = _get_ti(
         task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory"
     )
-    # task_render is executed with access to the database.
-    if isinstance(ti, TaskInstancePydantic):
-        raise ValueError("not a TaskInstance")
     with create_session() as session, set_current_task_instance_session(session=session):
         ti.render_templates()
         for attr in task.template_fields:
```
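The repeated `isinstance` guards above existed only because AIP-44 code paths could hand these commands a `TaskInstancePydantic`. With the union types narrowed to the ORM classes, the guards become dead code. A self-contained before/after sketch with stand-in classes (not the actual CLI code):

```python
class TaskInstance:
    """Stand-in for the ORM TaskInstance."""

    def current_state(self) -> str:
        return "success"


class TaskInstancePydantic:
    """Stand-in for the removed AIP-44 serialization class."""


# Before: the union type forced a runtime check in every DB-bound command.
def task_state_before(ti: TaskInstance | TaskInstancePydantic) -> None:
    if isinstance(ti, TaskInstancePydantic):
        raise ValueError("not a TaskInstance")
    print(ti.current_state())


# After: the narrower annotation makes the guard unnecessary.
def task_state_after(ti: TaskInstance) -> None:
    print(ti.current_state())


task_state_after(TaskInstance())  # success
```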
airflow/jobs/JOB_LIFECYCLE.md (4 additions & 4 deletions)

```diff
@@ -100,7 +100,7 @@ sequenceDiagram
     DB --> Internal API: Close Session
     deactivate DB
 
-    Internal API->>CLI component: JobPydantic object
+    Internal API->>CLI component: Job object
 
     CLI component->>JobRunner: Create Job Runner
     JobRunner ->> CLI component: JobRunner object
@@ -109,7 +109,7 @@ sequenceDiagram
 
     activate JobRunner
 
-    JobRunner->>Internal API: prepare_for_execution [JobPydantic]
+    JobRunner->>Internal API: prepare_for_execution [Job]
 
     Internal API-->>DB: Create Session
     activate DB
@@ -131,7 +131,7 @@ sequenceDiagram
     deactivate DB
     Internal API ->> JobRunner: returned data
     and
-    JobRunner->>Internal API: perform_heartbeat <br> [Job Pydantic]
+    JobRunner->>Internal API: perform_heartbeat <br> [Job]
     Internal API-->>DB: Create Session
     activate DB
     Internal API->>DB: perform_heartbeat [Job]
@@ -142,7 +142,7 @@ sequenceDiagram
     deactivate DB
     end
 
-    JobRunner->>Internal API: complete_execution <br> [Job Pydantic]
+    JobRunner->>Internal API: complete_execution <br> [Job]
     Internal API-->>DB: Create Session
     Internal API->>DB: complete_execution [Job]
     activate DB
```
airflow/jobs/base_job_runner.py (1 addition & 2 deletions)

```diff
@@ -26,7 +26,6 @@
     from sqlalchemy.orm import Session
 
     from airflow.jobs.job import Job
-    from airflow.serialization.pydantic.job import JobPydantic
 
 
 class BaseJobRunner:
@@ -64,7 +63,7 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
 
     @classmethod
     @provide_session
-    def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
+    def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | None:
         """Return the most recent job of this type, if any, based on last heartbeat received."""
         from airflow.jobs.job import most_recent_job
 
```
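`most_recent_job` keeps the usual `@provide_session` shape: `session` defaults to a sentinel, and the decorator injects a real session only when the caller did not supply one. A rough reimplementation of the idea, with simplified stand-ins rather than Airflow's actual utilities:

```python
import functools

NEW_SESSION = object()  # sentinel; stands in for airflow.utils.session.NEW_SESSION


class Session:
    """Hypothetical stand-in for a SQLAlchemy session."""

    def close(self) -> None:
        print("session closed")


def provide_session(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if kwargs.get("session", NEW_SESSION) is not NEW_SESSION:
            # The caller supplied a session; use it and leave its lifecycle alone.
            return func(*args, **kwargs)
        kwargs.pop("session", None)
        session = Session()
        try:
            return func(*args, session=session, **kwargs)
        finally:
            session.close()

    return wrapper


@provide_session
def most_recent_job(session=NEW_SESSION):
    print("querying with", session)


most_recent_job()  # the decorator creates and closes the session
```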
airflow/models/__init__.py (2 additions & 0 deletions)

```diff
@@ -82,6 +82,7 @@ def __getattr__(name):
 
 
 __lazy_imports = {
+    "Job": "airflow.jobs.job",
     "DAG": "airflow.models.dag",
     "ID_LEN": "airflow.models.base",
     "Base": "airflow.models.base",
@@ -112,6 +113,7 @@ def __getattr__(name):
 if TYPE_CHECKING:
     # I was unable to get mypy to respect a airflow/models/__init__.pyi, so
     # having to resort back to this hacky method
+    from airflow.jobs.job import Job
     from airflow.models.base import ID_LEN, Base
     from airflow.models.baseoperator import BaseOperator
     from airflow.models.baseoperatorlink import BaseOperatorLink
```
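Registering `Job` in both `__lazy_imports` and the `TYPE_CHECKING` block keeps `import airflow.models` cheap while still giving type checkers the real symbol. The mechanism is PEP 562's module-level `__getattr__`; a condensed, illustrative sketch (not the exact Airflow code):

```python
# Inside a package's __init__.py
import importlib

__lazy_imports = {
    "Job": "airflow.jobs.job",
    "DAG": "airflow.models.dag",
}


def __getattr__(name: str):
    # Invoked only when `name` is not found by normal module lookup (PEP 562).
    module_path = __lazy_imports.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    value = getattr(importlib.import_module(module_path), name)
    globals()[name] = value  # cache so later lookups skip __getattr__
    return value
```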
airflow/models/param.py (1 addition & 2 deletions)

```diff
@@ -31,7 +31,6 @@
     from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.dag import DAG
-    from airflow.serialization.pydantic.dag_run import DagRunPydantic
     from airflow.utils.context import Context
 
 logger = logging.getLogger(__name__)
@@ -334,7 +333,7 @@ def deserialize(cls, data: dict, dags: dict) -> DagParam:
 def process_params(
     dag: DAG,
     task: Operator,
-    dag_run: DagRun | DagRunPydantic | None,
+    dag_run: DagRun | None,
     *,
     suppress_exception: bool,
 ) -> dict[str, Any]:
```
airflow/models/skipmixin.py (2 additions & 3 deletions)

```diff
@@ -37,7 +37,6 @@
     from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.node import DAGNode
-    from airflow.serialization.pydantic.dag_run import DagRunPydantic
 
 # The key used by SkipMixin to store XCom data.
 XCOM_SKIPMIXIN_KEY = "skipmixin_key"
@@ -61,7 +60,7 @@ class SkipMixin(LoggingMixin):
 
     @staticmethod
     def _set_state_to_skipped(
-        dag_run: DagRun | DagRunPydantic,
+        dag_run: DagRun,
         tasks: Sequence[str] | Sequence[tuple[str, int]],
         session: Session,
     ) -> None:
@@ -95,7 +94,7 @@ def _set_state_to_skipped(
     @provide_session
     def skip(
         self,
-        dag_run: DagRun | DagRunPydantic,
+        dag_run: DagRun,
         tasks: Iterable[DAGNode],
         map_index: int = -1,
         session: Session = NEW_SESSION,
```
airflow/models/taskinstance.py (3 additions & 5 deletions)

```diff
@@ -163,8 +163,6 @@
     from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.dag import DAG
-    from airflow.serialization.pydantic.asset import AssetEventPydantic
-    from airflow.serialization.pydantic.dag import DagModelPydantic
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Literal, TypeGuard
     from airflow.utils.task_group import TaskGroup
@@ -984,7 +982,7 @@ def get_prev_end_date_success() -> pendulum.DateTime | None:
             return None
         return timezone.coerce_datetime(dagrun.end_date)
 
-    def get_triggering_events() -> dict[str, list[AssetEvent | AssetEventPydantic]]:
+    def get_triggering_events() -> dict[str, list[AssetEvent]]:
        if TYPE_CHECKING:
            assert session is not None
 
@@ -995,7 +993,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent | AssetEventPydantic]]:
         if dag_run not in session:
             dag_run = session.merge(dag_run, load=False)
         asset_events = dag_run.consumed_asset_events
-        triggering_events: dict[str, list[AssetEvent | AssetEventPydantic]] = defaultdict(list)
+        triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
         for event in asset_events:
             if event.asset:
                 triggering_events[event.asset.uri].append(event)
@@ -1890,7 +1888,7 @@ def _command_as_list(
     pool: str | None = None,
     cfg_path: str | None = None,
 ) -> list[str]:
-    dag: DAG | DagModel | DagModelPydantic | None
+    dag: DAG | DagModel | None
     # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
     if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None:
         if TYPE_CHECKING:
```
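Apart from the narrowed annotations, `get_triggering_events` is unchanged: it still buckets consumed events by asset URI. A tiny sketch of that `defaultdict` grouping with stand-in event objects:

```python
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Asset:
    """Stand-in for the asset attached to an event."""

    uri: str


@dataclass
class AssetEvent:
    """Stand-in for a consumed asset event row."""

    asset: Asset | None


events = [AssetEvent(Asset("s3://bucket/a")), AssetEvent(None), AssetEvent(Asset("s3://bucket/a"))]

triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
for event in events:
    if event.asset:  # events without a backing asset are skipped
        triggering_events[event.asset.uri].append(event)

print({uri: len(evs) for uri, evs in triggering_events.items()})  # {'s3://bucket/a': 2}
```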
airflow/serialization/pydantic/__init__.py (0 additions & 16 deletions)

This file was deleted.
