AIP-65: Track the serialized DAG across DagRun & TaskInstance
This tracks the serialized DAG version a task instance ran with by
establishing a foreign-key relationship between the entities instead of
relying on the dag_hash.
ephraimbuddy committed Oct 3, 2024
1 parent d769be2 commit b1ce903
Showing 15 changed files with 1,828 additions and 1,706 deletions.
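As a rough illustration of what this change enables (a sketch, not part of the commit; the dag_id/task_id/run_id values and session setup are assumptions), the exact serialized DAG version a task instance ran with becomes queryable through the new relationship:

from airflow.models.taskinstance import TaskInstance
from airflow.utils.session import create_session

with create_session() as session:
    # Placeholder identifiers for illustration only.
    ti = session.query(TaskInstance).filter_by(
        dag_id="example_dag", task_id="extract", run_id="scheduled__2024-10-03"
    ).first()
    if ti is not None and ti.serialized_dag is not None:
        # The FK resolves to a concrete SerializedDagModel row, so the
        # precise DAG version is recoverable, not just an MD5 hash.
        print(ti.serialized_dag.id, ti.serialized_dag.dag_hash)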
15 changes: 8 additions & 7 deletions airflow/jobs/scheduler_job_runner.py
@@ -1318,7 +1318,7 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
continue

- dag_hash = self.dagbag.dags_hash.get(dag.dag_id)
+ serialized_dag = SerializedDagModel.get(dag.dag_id, session=session)

data_interval = dag.get_next_data_interval(dag_model)
# Explicitly check if the DagRun already exists. This is an edge case
@@ -1338,7 +1338,7 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
data_interval=data_interval,
external_trigger=False,
session=session,
- dag_hash=dag_hash,
+ serialized_dag=serialized_dag,
creating_job_id=self.job.id,
triggered_by=DagRunTriggeredByType.TIMETABLE,
)
@@ -1397,7 +1397,7 @@ def _create_dag_runs_asset_triggered(
)
continue

- dag_hash = self.dagbag.dags_hash.get(dag.dag_id)
+ serialized_dag = SerializedDagModel.get(dag.dag_id, session=session)

# Explicitly check if the DagRun already exists. This is an edge case
# where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after`
@@ -1452,7 +1452,7 @@ def _create_dag_runs_asset_triggered(
state=DagRunState.QUEUED,
external_trigger=False,
session=session,
- dag_hash=dag_hash,
+ serialized_dag=serialized_dag,
creating_job_id=self.job.id,
triggered_by=DagRunTriggeredByType.DATASET,
)
@@ -1701,12 +1701,13 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) ->
Return True if we determine that DAG still exists.
"""
- latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session)
- if dag_run.dag_hash == latest_version:
+ latest_version = SerializedDagModel.get(dag_run.dag_id, session=session)
+
+ if latest_version and dag_run.serialized_dag_id == latest_version.id:
self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
return True

- dag_run.dag_hash = latest_version
+ dag_run.serialized_dag = latest_version

# Refresh the DAG
dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)
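The integrity check above now pins a DagRun to a specific serialized_dag row id instead of comparing MD5 hashes. A hedged sketch of the resulting flow (dag_run and session are assumed to come from the scheduler's ORM context; "example_dag" is a placeholder):

latest = SerializedDagModel.get("example_dag", session=session)
if latest is not None and dag_run.serialized_dag_id == latest.id:
    pass  # Same version row: verify_integrity can be skipped entirely.
else:
    # Re-pin the run to the latest serialized version before re-checking
    # that its task instances still match the DAG structure.
    dag_run.serialized_dag = latest
    dag_run.verify_integrity(session=session)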
@@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add SDM foreignkey to DagRun.
Revision ID: 4235395d5ec5
Revises: e1ff90d3efe9
Create Date: 2024-10-03 13:37:55.678831
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "4235395d5ec5"
down_revision = "e1ff90d3efe9"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"


def upgrade():
"""Apply Add SDM foreignkey to DagRun."""
with op.batch_alter_table("dag_run") as batch_op:
batch_op.add_column(sa.Column("serialized_dag_id", sa.Integer()))
batch_op.create_foreign_key(
"dag_run_serialized_dag_fkey",
"serialized_dag",
["serialized_dag_id"],
["id"],
ondelete="SET NULL",
)
batch_op.drop_column("dag_hash")

with op.batch_alter_table("task_instance") as batch_op:
batch_op.add_column(sa.Column("serialized_dag_id", sa.Integer()))
batch_op.create_foreign_key(
"task_instance_serialized_dag_fkey",
"serialized_dag",
["serialized_dag_id"],
["id"],
ondelete="SET NULL",
)

with op.batch_alter_table("task_instance_history") as batch_op:
batch_op.add_column(sa.Column("serialized_dag_id", sa.Integer()))


def downgrade():
"""Unapply Add SDM foreignkey to DagRun."""
with op.batch_alter_table("dag_run") as batch_op:
batch_op.add_column(sa.Column("dag_hash", sa.String(32)))
batch_op.drop_constraint("dag_run_serialized_dag_fkey", type_="foreignkey")
batch_op.drop_column("serialized_dag_id")

with op.batch_alter_table("task_instance") as batch_op:
batch_op.drop_constraint("task_instance_serialized_dag_fkey", type_="foreignkey")
batch_op.drop_column("serialized_dag_id")

with op.batch_alter_table("task_instance_history") as batch_op:
batch_op.drop_column("serialized_dag_id")
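The revision touches all three tables in one migration, and batch_alter_table keeps it portable to SQLite, which cannot add constraints in place. A sketch of driving this revision directly with Alembic's command API (the config path is an assumption; airflow db migrate is the normal entry point):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")             # hypothetical config location
command.upgrade(cfg, "4235395d5ec5")    # apply: add FKs, drop dag_run.dag_hash
command.downgrade(cfg, "e1ff90d3efe9")  # revert to the parent revision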
10 changes: 5 additions & 5 deletions airflow/models/dag.py
@@ -302,7 +302,7 @@ def _create_orm_dagrun(
conf,
state,
run_type,
- dag_hash,
+ serialized_dag,
creating_job_id,
data_interval,
session,
@@ -317,7 +317,7 @@ def _create_orm_dagrun(
conf=conf,
state=state,
run_type=run_type,
- dag_hash=dag_hash,
+ serialized_dag=serialized_dag,
creating_job_id=creating_job_id,
data_interval=data_interval,
triggered_by=triggered_by,
@@ -2542,7 +2542,7 @@ def create_dagrun(
conf: dict | None = None,
run_type: DagRunType | None = None,
session: Session = NEW_SESSION,
- dag_hash: str | None = None,
+ serialized_dag: SerializedDagModel | None = None,
creating_job_id: int | None = None,
data_interval: tuple[datetime, datetime] | None = None,
):
@@ -2561,7 +2561,7 @@ def create_dagrun(
:param conf: Dict containing configuration/parameters to pass to the DAG
:param creating_job_id: id of the job creating this DagRun
:param session: database session
- :param dag_hash: Hash of Serialized DAG
+ :param serialized_dag: The serialized Dag Model object
:param data_interval: Data interval of the DagRun
"""
logical_date = timezone.coerce_datetime(execution_date)
@@ -2627,7 +2627,7 @@ def create_dagrun(
conf=conf,
state=state,
run_type=run_type,
- dag_hash=dag_hash,
+ serialized_dag=serialized_dag,
creating_job_id=creating_job_id,
data_interval=data_interval,
session=session,
19 changes: 16 additions & 3 deletions airflow/models/dagrun.py
@@ -80,6 +80,7 @@

from airflow.models.dag import DAG
from airflow.models.operator import Operator
+ from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
@@ -141,7 +142,6 @@ class DagRun(Base, LoggingMixin):
data_interval_end = Column(UtcDateTime)
# When a scheduler last attempted to schedule TIs for this DagRun
last_scheduling_decision = Column(UtcDateTime)
- dag_hash = Column(String(32))
# Foreign key to LogTemplate. DagRun rows created prior to this column's
# existence have this set to NULL. Later rows automatically populate this on
# insert to point to the latest LogTemplate entry.
@@ -155,6 +155,11 @@
# This number is incremented only when the DagRun is re-Queued,
# when the DagRun is cleared.
clear_number = Column(Integer, default=0, nullable=False, server_default="0")
+ serialized_dag_id = Column(
+     Integer,
+     ForeignKey("serialized_dag.id", name="dag_run_serialized_dag_fkey", ondelete="SET NULL"),
+ )
+ serialized_dag = relationship("SerializedDagModel", back_populates="dag_run", lazy="joined")

# Remove this `if` after upgrading Sphinx-AutoAPI
if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
@@ -218,7 +223,7 @@ def __init__(
conf: Any | None = None,
state: DagRunState | None = None,
run_type: str | None = None,
- dag_hash: str | None = None,
+ serialized_dag: SerializedDagModel | None = None,
creating_job_id: int | None = None,
data_interval: tuple[datetime, datetime] | None = None,
triggered_by: DagRunTriggeredByType | None = None,
@@ -242,7 +247,7 @@
else:
self.queued_at = queued_at
self.run_type = run_type
- self.dag_hash = dag_hash
+ self.serialized_dag = serialized_dag
self.creating_job_id = creating_job_id
self.clear_number = 0
self.triggered_by = triggered_by
@@ -354,6 +359,14 @@ def set_state(self, state: DagRunState) -> None:
def state(self):
return synonym("_state", descriptor=property(self.get_state, self.set_state))

+ @property
+ def dag_hash(self):
+     if self.serialized_dag:
+         return self.serialized_dag.dag_hash
+     # TODO: Should we avoid serialized DAG deletion since
+     # we can have multiple versions of same dag?
+     return "SerializedDAG Deleted"

@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
"""
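Because the new foreign keys are declared with ondelete="SET NULL", deleting a serialized_dag row silently unlinks existing runs, and the dag_hash property above papers over that with a sentinel string. A small sketch of the resulting contract, assuming dag_run is loaded from an ORM session:

if dag_run.serialized_dag is None:
    # The serialized DAG row was deleted and the FK was nulled out.
    assert dag_run.dag_hash == "SerializedDAG Deleted"
else:
    # The old attribute still resolves, now via the relationship.
    assert dag_run.dag_hash == dag_run.serialized_dag.dag_hash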
13 changes: 5 additions & 8 deletions airflow/models/serialized_dag.py
@@ -34,19 +34,19 @@
String,
UniqueConstraint,
and_,
+ delete,
exc,
or_,
select,
)
- from sqlalchemy.orm import backref, foreign, relationship
+ from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql.expression import func, literal

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import TaskNotFound
from airflow.models.base import ID_LEN, Base
from airflow.models.dag import DagModel
from airflow.models.dagcode import DagCode
- from airflow.models.dagrun import DagRun
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.settings import COMPRESS_SERIALIZED_DAGS, MIN_SERIALIZED_DAG_UPDATE_INTERVAL, json
@@ -106,11 +106,8 @@ class SerializedDagModel(Base):
UniqueConstraint("dag_hash", "version_number", name="dag_hash_version_number_unique"),
)

- dag_runs = relationship(
-     DagRun,
-     primaryjoin=dag_id == foreign(DagRun.dag_id),  # type: ignore
-     backref=backref("serialized_dag", uselist=False, innerjoin=True),
- )
+ dag_run = relationship("DagRun", back_populates="serialized_dag")
+ task_instance = relationship("TaskInstance", back_populates="serialized_dag")

dag_model = relationship(
DagModel,
@@ -311,7 +308,7 @@ def remove_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> None:
:param dag_id: dag_id to be deleted
:param session: ORM Session.
"""
- session.execute(cls.__table__.delete().where(cls.dag_id == dag_id))
+ session.execute(delete(cls).where(cls.dag_id == dag_id))

@classmethod
@internal_api_call
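remove_dag also moves from the imperative cls.__table__.delete() to SQLAlchemy's 2.0-style delete() construct; the emitted SQL is equivalent. Roughly what it executes (a sketch; "example_dag" and the session are placeholders):

from sqlalchemy import delete

from airflow.models.serialized_dag import SerializedDagModel

stmt = delete(SerializedDagModel).where(SerializedDagModel.dag_id == "example_dag")
session.execute(stmt)  # dag_run/task_instance FKs become NULL via ondelete="SET NULL"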
8 changes: 6 additions & 2 deletions airflow/models/taskinstance.py
@@ -43,6 +43,7 @@
Column,
DateTime,
Float,
+ ForeignKey,
ForeignKeyConstraint,
Index,
Integer,
@@ -1858,8 +1859,11 @@ class TaskInstance(Base, LoggingMixin):
next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))

_task_display_property_value = Column("task_display_name", String(2000), nullable=True)
- # If adding new fields here then remember to add them to
- # refresh_from_db() or they won't display in the UI correctly
+ serialized_dag_id = Column(
+     Integer,
+     ForeignKey("serialized_dag.id", name="task_instance_serialized_dag_fkey", ondelete="SET NULL"),
+ )
+ serialized_dag = relationship("SerializedDagModel", back_populates="task_instance")

__table_args__ = (
Index("ti_dag_state", dag_id, state),
4 changes: 2 additions & 2 deletions airflow/models/taskinstancehistory.py
@@ -92,6 +92,7 @@ class TaskInstanceHistory(Base):
next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))

task_display_name = Column("task_display_name", String(2000), nullable=True)
+ serialized_dag_id = Column(Integer)

def __init__(
self,
@@ -100,10 +101,9 @@ def __init__(
):
super().__init__()
for column in self.__table__.columns:
- if column.name == "id":
+ if column.name in ["id", "dag_hash"]:
continue
setattr(self, column.name, getattr(ti, column.name))

if state:
self.state = state

2 changes: 1 addition & 1 deletion airflow/utils/db.py
@@ -96,7 +96,7 @@ class MappedClassProtocol(Protocol):
"2.9.0": "1949afb29106",
"2.9.2": "686269002441",
"2.10.0": "22ed7efa9da2",
"3.0.0": "e1ff90d3efe9",
"3.0.0": "4235395d5ec5",
}


2 changes: 1 addition & 1 deletion airflow/www/views.py
@@ -2254,7 +2254,7 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION):
state=DagRunState.QUEUED,
conf=run_conf,
external_trigger=True,
- dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
+ serialized_dag=SerializedDagModel.get(dag.dag_id),
run_id=run_id,
triggered_by=DagRunTriggeredByType.UI,
)
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
- a3ab51a373ba0bc58d2713b6df7b3404b3efa951543be23fea624852101c5d34
+ e01b3a34963248f25786de5c6683f0152fe8592096d4568c3afc54d634e20fb8
