2 changes: 0 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -64,7 +64,6 @@
from airflow.exceptions import ParamValidationError
from airflow.listeners.listener import get_listener_manager
from airflow.models import DAG, DagModel, DagRun
from airflow.models.dag_version import DagVersion
from airflow.timetables.base import DataInterval
from airflow.utils import timezone
from airflow.utils.state import DagRunState
@@ -393,7 +392,6 @@ def trigger_dag_run(
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.REST_API,
external_trigger=True,
dag_version=DagVersion.get_latest_version(dag.dag_id),
state=DagRunState.QUEUED,
session=session,
)
@@ -203,7 +203,7 @@ class TaskInstance(StrictBaseModel):
"""Schema for TaskInstance model with minimal required fields needed for Runtime."""

id: uuid.UUID

dag_version_id: uuid.UUID
Member: Do we need this during runtime though? Not seeing where it is used.

task_id: str
dag_id: str
run_id: str
6 changes: 5 additions & 1 deletion airflow/cli/commands/remote_commands/task_command.py
@@ -42,6 +42,7 @@
from airflow.listeners.listener import get_listener_manager
from airflow.models import TaskInstance
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskReturnCode
from airflow.sdk.definitions.param import ParamsDict
@@ -211,7 +212,10 @@ def _get_ti(
f"run_id or logical_date of {logical_date_or_run_id!r} not found"
)
# TODO: Validate map_index is in range?
ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index)
dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index, dag_version_id=dag_version.id)
Comment on lines +215 to +218
Member: Since this pattern is used so often, maybe it's better for get_latest_version to only return DagVersion and raise an exception if one cannot be found.
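As an illustration of that suggestion, a minimal sketch written as a standalone helper rather than a change to the real classmethod; the `DagVersion.get_latest_version(dag_id, session=...)` call and `AirflowException` appear elsewhere in Airflow, while the helper name here is made up:

```python
from airflow.exceptions import AirflowException
from airflow.models.dag_version import DagVersion
from airflow.utils.session import NEW_SESSION, provide_session
from sqlalchemy.orm import Session


@provide_session
def get_latest_version_or_raise(dag_id: str, *, session: Session = NEW_SESSION) -> DagVersion:
    # Hypothetical helper: never return None, so call sites can drop the
    # `if TYPE_CHECKING: assert dag_version` pattern repeated in this PR.
    dag_version = DagVersion.get_latest_version(dag_id, session=session)
    if dag_version is None:
        raise AirflowException(f"No DagVersion found for dag {dag_id!r}")
    return dag_version
```

Call sites like `_get_ti` above could then drop the assert and pass the returned version's `.id` directly.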

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, maybe TI init can handle this for folks?

Contributor (author): @jedcunningham, do you mean, if it's not provided, use the latest dag version ID in init? The only thing is that session is not provided in init.
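For illustration only, a hedged sketch of the trade-off being discussed: if `__init__` defaulted the value itself, it would have to open its own session just for the lookup (the helper name is hypothetical):

```python
from airflow.models.dag_version import DagVersion
from airflow.utils.session import create_session


def _default_dag_version_id(dag_id: str):
    # Hypothetical fallback __init__ could call when dag_version_id is not passed.
    # The catch noted above: __init__ receives no session, so this opens one
    # solely for the lookup, which is presumably why the PR keeps the explicit
    # keyword argument at every call site instead.
    with create_session() as session:
        dag_version = DagVersion.get_latest_version(dag_id, session=session)
        return dag_version.id if dag_version else None
```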

if dag_run in session:
session.add(ti)
ti.dag_run = dag_run
2 changes: 1 addition & 1 deletion airflow/executors/workloads.py
@@ -51,7 +51,7 @@ class TaskInstance(BaseModel):
"""Schema for TaskInstance with minimal required fields needed for Executors and Task SDK."""

id: uuid.UUID

dag_version_id: uuid.UUID
task_id: str
dag_id: str
run_id: str
@@ -141,7 +141,7 @@ def upgrade():
batch_op.add_column(sa.Column("created_at", UtcDateTime(), nullable=False, default=timezone.utcnow))

with op.batch_alter_table("task_instance", schema=None) as batch_op:
batch_op.add_column(sa.Column("dag_version_id", UUIDType(binary=False)))
batch_op.add_column(sa.Column("dag_version_id", UUIDType(binary=False), nullable=False))
Member: What happens to old TIs created before this column is added? I think we need some data migration work here.

Contributor (author): Hmm. Yeah, quite needed.
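For context, a rough sketch of the kind of backfill being asked for, not what this PR does; it assumes existing rows can simply be pointed at the newest `dag_version` row for their DAG and that `dag_version` carries a `created_at` column:

```python
import sqlalchemy as sa
from alembic import op


def _backfill_dag_version_id() -> None:
    # Illustrative only: give pre-existing task_instance rows some dag_version_id
    # before the column is made NOT NULL, here the newest version of their DAG.
    op.get_bind().execute(
        sa.text(
            """
            UPDATE task_instance
            SET dag_version_id = (
                SELECT dv.id
                FROM dag_version dv
                WHERE dv.dag_id = task_instance.dag_id
                ORDER BY dv.created_at DESC
                LIMIT 1
            )
            WHERE dag_version_id IS NULL
            """
        )
    )
```

Only after such a backfill (or an equivalent server-side default) could the column safely be flipped to NOT NULL.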

batch_op.create_foreign_key(
batch_op.f("task_instance_dag_version_id_fkey"),
"dag_version",
9 changes: 7 additions & 2 deletions airflow/models/abstractoperator.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import datetime
import uuid
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, Callable

@@ -153,7 +154,9 @@ def priority_weight_total(self) -> int:
)
)

def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
def expand_mapped_task(
self, run_id: str, *, dag_version_id: uuid.UUID, session: Session
) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.

@@ -262,7 +265,9 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

for index in indexes_to_map:
# TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
ti = TaskInstance(
self, run_id=run_id, map_index=index, state=state, dag_version_id=dag_version_id
)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
11 changes: 9 additions & 2 deletions airflow/models/baseoperator.py
@@ -60,6 +60,7 @@
NotMapped,
)
from airflow.models.base import _sentinel
from airflow.models.dag_version import DagVersion
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
@@ -623,7 +624,10 @@ def run(
DagRun.logical_date == info.logical_date,
)
).one()
ti = TaskInstance(self, run_id=dag_run.run_id)
dag_version = DagVersion.get_latest_version(self.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(self, run_id=dag_run.run_id, dag_version_id=dag_version.id)
except NoResultFound:
# This is _mostly_ only used in tests
dr = DagRun(
@@ -640,7 +644,10 @@
triggered_by=DagRunTriggeredByType.TEST,
state=DagRunState.RUNNING,
)
ti = TaskInstance(self, run_id=dr.run_id)
dag_version = DagVersion.get_latest_version(self.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(self, run_id=dr.run_id, dag_version_id=dag_version.id)
ti.dag_run = dr
session.add(dr)
session.flush()
28 changes: 20 additions & 8 deletions airflow/models/dagrun.py
@@ -19,6 +19,7 @@

import itertools
import os
import uuid
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar, overload
@@ -1208,7 +1209,9 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
# the db references.
ti.clear_db_references(session=session)
try:
expanded_tis, _ = TaskMap.expand_mapped_task(ti.task, self.run_id, session=session)
expanded_tis, _ = TaskMap.expand_mapped_task(
ti.dag_version_id, ti.task, self.run_id, session=session
)
except NotMapped: # Not a mapped task, nothing needed.
return None
if expanded_tis:
@@ -1242,7 +1245,11 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
# It's enough to revise map index once per task id,
# checking the map index for each mapped task significantly slows down scheduling
if schedulable.task.task_id not in revised_map_index_task_ids:
ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
ready_tis.extend(
self._revise_map_indexes_if_mapped(
schedulable.dag_version_id, schedulable.task, session=session
)
)
revised_map_index_task_ids.add(schedulable.task.task_id)
ready_tis.append(schedulable)

@@ -1346,7 +1353,10 @@ def _emit_duration_stats_for_finished_state(self):

@provide_session
def verify_integrity(
self, *, session: Session = NEW_SESSION, dag_version_id: UUIDType | None = None
self,
*,
dag_version_id: UUIDType,
session: Session = NEW_SESSION,
) -> None:
"""
Verify the DagRun by checking for removed tasks or tasks that are not in the database yet.
@@ -1477,7 +1487,7 @@ def _get_task_creator(
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[True],
dag_version_id: UUIDType | None,
dag_version_id: uuid.UUID,
) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]: ...

@overload
Expand All @@ -1486,15 +1496,15 @@ def _get_task_creator(
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[False],
dag_version_id: UUIDType | None,
dag_version_id: uuid.UUID,
) -> Callable[[Operator, Iterable[int]], Iterator[TI]]: ...

def _get_task_creator(
self,
created_counts: dict[str, int],
ti_mutation_hook: Callable,
hook_is_noop: Literal[True, False],
dag_version_id: UUIDType | None,
dag_version_id: uuid.UUID,
) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
"""
Get the task creator function.
@@ -1605,7 +1615,9 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
def _revise_map_indexes_if_mapped(
self, dag_version_id: uuid.UUID, task: Operator, *, session: Session
) -> Iterator[TI]:
"""
Check if task increased or reduced in length and handle appropriately.

@@ -1651,7 +1663,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
for index in range(total_length):
if index in existing_indexes:
continue
ti = TI(task, run_id=self.run_id, map_index=index, state=None)
ti = TI(task, run_id=self.run_id, map_index=index, state=None, dag_version_id=dag_version_id)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
14 changes: 10 additions & 4 deletions airflow/models/taskinstance.py
@@ -26,6 +26,7 @@
import os
import signal
import traceback
import uuid
from collections import defaultdict
from collections.abc import Collection, Generator, Iterable, Mapping
from datetime import timedelta
@@ -1691,7 +1692,9 @@ class TaskInstance(Base, LoggingMixin):
next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))

_task_display_property_value = Column("task_display_name", String(2000), nullable=True)
dag_version_id = Column(UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"))
dag_version_id = Column(
UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"), nullable=False
)
dag_version = relationship("DagVersion", back_populates="task_instances")
# If adding new fields here then remember to add them to
# _set_ti_attrs() or they won't display in the UI correctly
@@ -1757,10 +1760,11 @@ class TaskInstance(Base, LoggingMixin):
def __init__(
self,
task: Operator,
*,
dag_version_id: uuid.UUID,
run_id: str | None = None,
state: str | None = None,
map_index: int = -1,
dag_version_id: UUIDType | None = None,
):
super().__init__()
self.dag_id = task.dag_id
@@ -1799,15 +1803,15 @@ def stats_tags(self) -> dict[str, str]:

@staticmethod
def insert_mapping(
run_id: str, task: Operator, map_index: int, dag_version_id: UUIDType | None
run_id: str, task: Operator, map_index: int, dag_version_id: uuid.UUID
) -> dict[str, Any]:
"""
Insert mapping.

:meta private:
"""
priority_weight = task.weight_rule.get_weight(
TaskInstance(task=task, run_id=run_id, map_index=map_index)
TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
)

return {
@@ -1854,6 +1858,7 @@ def from_runtime_ti(cls, runtime_ti: RuntimeTaskInstanceProtocol) -> TaskInstanc
run_id=runtime_ti.run_id,
task=runtime_ti.task, # type: ignore[arg-type]
map_index=runtime_ti.map_index,
dag_version_id=runtime_ti.dag_version_id,
)

if TYPE_CHECKING:
@@ -1869,6 +1874,7 @@ def to_runtime_ti(self, context_from_server) -> RuntimeTaskInstanceProtocol:
task_id=self.task_id,
dag_id=self.dag_id,
run_id=self.run_id,
dag_version_id=self.dag_version_id,
try_numer=self.try_number,
map_index=self.map_index,
task=self.task,
10 changes: 8 additions & 2 deletions airflow/models/taskmap.py
@@ -21,6 +21,7 @@

import collections.abc
import enum
import uuid
from collections.abc import Collection, Iterable, Sequence
from typing import TYPE_CHECKING, Any

@@ -121,7 +122,9 @@ def variant(self) -> TaskMapVariant:
return TaskMapVariant.DICT

@classmethod
def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
def expand_mapped_task(
cls, dag_version_id: uuid.UUID, task, run_id: str, *, session: Session
) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.

@@ -224,7 +227,10 @@ def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Seq

for index in indexes_to_map:
# TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
ti = TaskInstance(task, run_id=run_id, map_index=index, state=state)

ti = TaskInstance(
task, run_id=run_id, map_index=index, state=state, dag_version_id=dag_version_id
)
task.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
617116c74735faa69a297fe665664b691f341487bc8dcbef3e3ec7e76cdea799
9cc3230fc60a08ab5be6779b9fec70089bf2f9eeeb120a4f746628e2e7582989
3 changes: 2 additions & 1 deletion docs/apache-airflow/img/airflow_erd.svg
(SVG diff not rendered.)
2 changes: 2 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -21,6 +21,7 @@

import os
import sys
import uuid
from collections.abc import Iterable, Mapping
from datetime import datetime, timezone
from io import FileIO
@@ -92,6 +93,7 @@ class RuntimeTaskInstance(TaskInstance):
model_config = ConfigDict(arbitrary_types_allowed=True)

task: BaseOperator
dag_version_id: uuid.UUID
Member: Same with these in task_sdk?

_ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
"""The Task Instance context from the API server, if any."""

2 changes: 2 additions & 0 deletions task_sdk/src/airflow/sdk/types.py
@@ -17,6 +17,7 @@

from __future__ import annotations

import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Protocol, Union

@@ -56,6 +57,7 @@ class RuntimeTaskInstanceProtocol(Protocol):
"""Minimal interface for a task instance available during the execution."""

task: BaseOperator
dag_version_id: uuid.UUID
task_id: str
dag_id: str
run_id: str