Skip to content

Commit

Permalink
a little refactor
Browse files: browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Oct 3, 2024
1 parent b9f890a commit d769be2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 25 deletions.
13 changes: 6 additions & 7 deletions airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def write_dag(
log.debug("DAG: %s written to the DB", dag.dag_id)
return True

@classmethod
def _latest_item_select_object(cls, dag_id):
    """Build the SELECT statement for the most recent row stored for ``dag_id``.

    The statement filters rows on ``dag_id``, orders them by ``id`` in
    descending order and limits the result to one row, so executing it
    yields at most the row with the highest ``id``. Nothing is executed
    here; callers pass the returned statement to a session.

    :param dag_id: the DAG whose latest row should be selected
    :return: the constructed select statement
    """
    rows_for_dag = select(cls).where(cls.dag_id == dag_id)
    highest_id_first = rows_for_dag.order_by(cls.id.desc())
    return highest_id_first.limit(1)

@classmethod
@provide_session
def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDAG]:
Expand Down Expand Up @@ -372,7 +376,7 @@ def get(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDagModel
:param dag_id: the DAG to fetch
:param session: ORM Session
"""
return session.scalar(select(cls).where(cls.dag_id == dag_id).order_by(cls.id.desc()).limit(1))
return session.scalar(cls._latest_item_select_object(dag_id))

@staticmethod
@provide_session
Expand Down Expand Up @@ -491,12 +495,7 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[
def get_serialized_dag(dag_id: str, task_id: str, session: Session = NEW_SESSION) -> Operator | None:
try:
# get the latest version of the DAG
model = session.scalar(
select(SerializedDagModel)
.where(SerializedDagModel.dag_id == dag_id)
.order_by(SerializedDagModel.id.desc())
.limit(1)
)
model = session.scalar(SerializedDagModel._latest_item_select_object(dag_id))
if model:
return model.dag.get_task(task_id)
except (exc.NoResultFound, TaskNotFound):
Expand Down
24 changes: 6 additions & 18 deletions tests/models/test_serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,12 @@ def test_serialized_dag_is_updated_if_dag_is_changed(self):
assert dag_updated is True

with create_session() as session:
s_dag = session.scalar(
select(SDM).where(SDM.dag_id == example_bash_op_dag.dag_id).order_by(SDM.id.desc()).limit(1)
)
s_dag = session.scalar(SDM._latest_item_select_object(example_bash_op_dag.dag_id))

# Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
# column is not updated
dag_updated = SDM.write_dag(dag=example_bash_op_dag)
s_dag_1 = session.scalar(
select(SDM).where(SDM.dag_id == example_bash_op_dag.dag_id).order_by(SDM.id.desc()).limit(1)
)
s_dag_1 = session.scalar(SDM._latest_item_select_object(example_bash_op_dag.dag_id))

assert s_dag_1.dag_hash == s_dag.dag_hash
assert s_dag.last_updated == s_dag_1.last_updated
Expand All @@ -121,9 +117,7 @@ def test_serialized_dag_is_updated_if_dag_is_changed(self):
assert example_bash_op_dag.tags == {"example", "example2", "new_tag"}

dag_updated = SDM.write_dag(dag=example_bash_op_dag)
s_dag_2 = session.scalar(
select(SDM).where(SDM.dag_id == example_bash_op_dag.dag_id).order_by(SDM.id.desc()).limit(1)
)
s_dag_2 = session.scalar(SDM._latest_item_select_object(example_bash_op_dag.dag_id))

assert s_dag.last_updated != s_dag_2.last_updated
assert s_dag.dag_hash != s_dag_2.dag_hash
Expand All @@ -139,16 +133,12 @@ def test_serialized_dag_is_updated_if_processor_subdir_changed(self):
assert dag_updated is True

with create_session() as session:
s_dag = session.scalar(
select(SDM).where(SDM.dag_id == example_bash_op_dag.dag_id).order_by(SDM.id.desc()).limit(1)
)
s_dag = session.scalar(SDM._latest_item_select_object(example_bash_op_dag.dag_id))

# Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
# column is not updated
dag_updated = SDM.write_dag(dag=example_bash_op_dag, processor_subdir="/tmp/test")
s_dag_1 = session.scalar(
select(SDM).where(SDM.dag_id == example_bash_op_dag.dag_id).order_by(SDM.id.desc()).limit(1)
)
s_dag_1 = session.scalar(SDM._latest_item_select_object(example_bash_op_dag.dag_id))

assert s_dag_1.dag_hash == s_dag.dag_hash
assert s_dag.last_updated == s_dag_1.last_updated
Expand All @@ -157,9 +147,7 @@ def test_serialized_dag_is_updated_if_processor_subdir_changed(self):

# Update DAG
dag_updated = SDM.write_dag(dag=example_bash_op_dag, processor_subdir="/tmp/other")
s_dag_2 = session.scalar(
select(SDM).where(SDM.dag_id == example_bash_op_dag.dag_id).order_by(SDM.id.desc()).limit(1)
)
s_dag_2 = session.scalar(SDM._latest_item_select_object(example_bash_op_dag.dag_id))

assert s_dag.processor_subdir != s_dag_2.processor_subdir
assert dag_updated is True
Expand Down

0 comments on commit d769be2

Please sign in to comment.