Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import inspect
from collections import abc
from collections.abc import Iterable
from datetime import datetime, timedelta
from typing import Any

Expand Down Expand Up @@ -151,9 +150,9 @@ class DAGDetailsResponse(DAGResponse):
start_date: datetime | None
end_date: datetime | None
is_paused_upon_creation: bool | None
params: abc.MutableMapping | None
params: abc.Mapping | None
render_template_as_native_obj: bool
template_search_path: Iterable[str] | None
template_search_path: list[str] | None
timezone: str | None
last_parsed: datetime | None
default_args: abc.Mapping | None
Expand All @@ -177,7 +176,7 @@ def get_doc_md(cls, doc_md: str | None) -> str | None:

@field_validator("params", mode="before")
@classmethod
def get_params(cls, params: abc.MutableMapping | None) -> dict | None:
def get_params(cls, params: abc.Mapping | None) -> dict | None:
"""Convert params attribute to dict representation."""
if params is None:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

from __future__ import annotations

from typing import TYPE_CHECKING, cast

from fastapi import Depends, HTTPException, status
from sqlalchemy.sql import select

Expand All @@ -31,10 +29,6 @@
from airflow.exceptions import TaskNotFound
from airflow.models import DagRun

if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator

extra_links_router = AirflowRouter(
tags=["Extra Links"], prefix="/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/links"
)
Expand Down Expand Up @@ -62,8 +56,7 @@ def get_extra_links(
dag = get_dag_for_run_or_latest_version(dag_bag, dag_run, dag_id, session)

try:
# TODO (GH-52141): Make dag a db-backed object so it only returns db-backed tasks.
task = cast("MappedOperator | SerializedBaseOperator", dag.get_task(task_id))
task = dag.get_task(task_id)
except TaskNotFound:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task with ID = {task_id} not found")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,7 @@ def ti_run(

if dag := dag_bag.get_dag_for_run(dag_run=dr, session=session):
upstream_map_indexes = dict(
_get_upstream_map_indexes(
# TODO (GH-52141): This get_task should return scheduler
# types instead, but currently it inherits SDK's DAG.
cast("MappedOperator | SerializedBaseOperator", dag.get_task(ti.task_id)),
ti.map_index,
ti.run_id,
session=session,
)
_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session=session)
)
else:
upstream_map_indexes = None
Expand Down
10 changes: 3 additions & 7 deletions airflow-core/src/airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _generate_temporary_run_id() -> str:

def _get_dag_run(
*,
dag: DAG,
dag: DAG | SerializedDAG,
create_if_necessary: CreateIfNecessary,
logical_date_or_run_id: str | None = None,
session: Session | None = None,
Expand Down Expand Up @@ -274,9 +274,7 @@ def task_state(args) -> None:
"""
if not (dag := SerializedDagModel.get_dag(args.dag_id)):
raise SystemExit(f"Can not find dag {args.dag_id!r}")
# TODO (GH-52141): get_task in scheduler needs to return scheduler types
# instead, but currently it inherits SDK's DAG.
task = cast("Operator", dag.get_task(task_id=args.task_id))
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
print(ti.state)

Expand Down Expand Up @@ -434,9 +432,7 @@ def task_render(args, dag: DAG | None = None) -> None:
dag = get_bagged_dag(args.bundle_name, args.dag_id)
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
ti, _ = _get_ti(
# TODO (GH-52141): get_task in scheduler needs to return scheduler types
# instead, but currently it inherits SDK's DAG.
cast("Operator", serialized_dag.get_task(task_id=args.task_id)),
serialized_dag.get_task(task_id=args.task_id),
args.map_index,
logical_date_or_run_id=args.logical_date_or_run_id,
create_if_necessary="memory",
Expand Down
9 changes: 3 additions & 6 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from datetime import date, datetime, timedelta
from functools import lru_cache, partial
from itertools import groupby
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

from sqlalchemy import and_, delete, desc, exists, func, inspect, or_, select, text, tuple_, update
from sqlalchemy.exc import OperationalError
Expand Down Expand Up @@ -97,9 +97,8 @@
from airflow._shared.logging.types import Logger
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.executor_utils import ExecutorName
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstanceKey
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.sqlalchemy import CommitProhibitorGuard

TI = TaskInstance
Expand Down Expand Up @@ -917,9 +916,7 @@ def process_executor_events(
)
if TYPE_CHECKING:
assert dag
# TODO (GH-52141): get_task in scheduler needs to return scheduler types
# instead, but currently it inherits SDK's DAG.
task = cast("MappedOperator | SerializedBaseOperator", dag.get_task(ti.task_id))
task = dag.get_task(ti.task_id)
except Exception:
cls.logger().exception("Marking task instance %s as %s", ti, state)
ti.set_state(state)
Expand Down
14 changes: 8 additions & 6 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ def recalculate(self) -> _UnfinishedStates:
self.notify_dagrun_state_changed(msg="task_failure")

if execute_callbacks and dag.has_on_failure_callback:
self.handle_dag_callback(dag=dag, success=False, reason="task_failure")
self.handle_dag_callback(dag=cast("SDKDAG", dag), success=False, reason="task_failure")
elif dag.has_on_failure_callback:
callback = DagCallbackRequest(
filepath=self.dag_model.relative_fileloc,
Expand Down Expand Up @@ -1206,7 +1206,7 @@ def recalculate(self) -> _UnfinishedStates:
self.notify_dagrun_state_changed(msg="success")

if execute_callbacks and dag.has_on_success_callback:
self.handle_dag_callback(dag=dag, success=True, reason="success")
self.handle_dag_callback(dag=cast("SDKDAG", dag), success=True, reason="success")
elif dag.has_on_success_callback:
callback = DagCallbackRequest(
filepath=self.dag_model.relative_fileloc,
Expand Down Expand Up @@ -1237,7 +1237,11 @@ def recalculate(self) -> _UnfinishedStates:
self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")

if execute_callbacks and dag.has_on_failure_callback:
self.handle_dag_callback(dag=dag, success=False, reason="all_tasks_deadlocked")
self.handle_dag_callback(
dag=cast("SDKDAG", dag),
success=False,
reason="all_tasks_deadlocked",
)
elif dag.has_on_failure_callback:
callback = DagCallbackRequest(
filepath=self.dag_model.relative_fileloc,
Expand Down Expand Up @@ -1306,9 +1310,7 @@ def _filter_tis_and_exclude_removed(dag: SerializedDAG, tis: list[TI]) -> Iterab
"""Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""
for ti in tis:
try:
# TODO (GH-52141): get_task in scheduler needs to return scheduler types
# instead, but currently it inherits SDK's DAG.
ti.task = cast("Operator", dag.get_task(ti.task_id))
ti.task = dag.get_task(ti.task_id)
except TaskNotFound:
if ti.state != TaskInstanceState.REMOVED:
self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
Expand Down
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class MappedOperator(DAGNode):
dag: SerializedDAG = attrs.field(init=False) # type: ignore[assignment]
task_group: SerializedTaskGroup = attrs.field(init=False) # type: ignore[assignment]

doc: str | None = attrs.field(init=False)
doc_json: str | None = attrs.field(init=False)
doc_rst: str | None = attrs.field(init=False)
doc_yaml: str | None = attrs.field(init=False)
start_date: pendulum.DateTime | None = attrs.field(init=False, default=None)
end_date: pendulum.DateTime | None = attrs.field(init=False, default=None)
upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
Expand Down
13 changes: 7 additions & 6 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from collections.abc import Collection, Iterable
from datetime import timedelta
from functools import cache
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
from urllib.parse import quote

import attrs
Expand Down Expand Up @@ -237,8 +237,7 @@ def clear_task_instances(
log.warning("No serialized dag found for dag '%s'", dr.dag_id)
task_id = ti.task_id
if ti_dag and ti_dag.has_task(task_id):
# TODO (GH-52141): Make dag a db-backed object so it only returns db-backed tasks.
task = cast("Operator", ti_dag.get_task(task_id))
task = ti_dag.get_task(task_id)
ti.refresh_from_task(task)
if TYPE_CHECKING:
assert ti.task
Expand Down Expand Up @@ -1455,9 +1454,11 @@ def run(
assert original_task is not None
assert original_task.dag is not None

self.task = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(original_task.dag)).task_dict[
original_task.task_id
]
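# Not all tests set this up consistently: some already attach a SerializedDAG
# to the task, so only round-trip through serialization when the DAG is not
# serialized yet.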
if not isinstance(original_task.dag, SerializedDAG):
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(original_task.dag))
self.task = serialized_dag.get_task(original_task.task_id)

res = self.check_and_change_state_before_execution(
verbose=verbose,
ignore_all_deps=ignore_all_deps,
Expand Down
5 changes: 2 additions & 3 deletions airflow-core/src/airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from collections.abc import Iterator, Sequence
from functools import singledispatch
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from typing import TYPE_CHECKING, Any, TypeAlias

import attrs
from sqlalchemy import func, or_, select
Expand Down Expand Up @@ -92,8 +92,7 @@ class SchedulerPlainXComArg(SchedulerXComArg):

@classmethod
def _deserialize(cls, data: dict[str, Any], dag: SerializedDAG) -> Self:
# TODO (GH-52141): SerializedDAG should return scheduler operator instead.
return cls(cast("Operator", dag.get_task(data["task_id"])), data["key"])
return cls(dag.get_task(data["task_id"]), data["key"])

def iter_references(self) -> Iterator[tuple[Operator, str]]:
yield self.operator, self.key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class SerializedTaskGroup(DAGNode):
group_display_name: str | None = attrs.field()
prefix_group_id: bool = attrs.field()
parent_group: SerializedTaskGroup | None = attrs.field()
dag: SerializedDAG = attrs.field()
# TODO (GH-52141): Replace DAGNode dependency.
dag: SerializedDAG = attrs.field() # type: ignore[assignment]
tooltip: str = attrs.field()
default_args: dict[str, Any] = attrs.field(factory=dict)

Expand Down
Loading
Loading