Skip to content

Commit

Permalink
fixup! Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Oct 16, 2024
1 parent be19433 commit ad4f57c
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 62 deletions.
5 changes: 3 additions & 2 deletions airflow/example_dags/plugins/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,10 @@ def on_dag_run_running(dag_run: DagRun, msg: str):
"""
print("Dag run in running state")
queued_at = dag_run.queued_at
dag_hash_info = dag_run.dag_version.serialized_dag.dag_hash

print(f"Dag information Queued at: {queued_at} hash info: {dag_hash_info}")
dag_version = dag_run.dag_version.version if dag_run.dag_version else None

print(f"Dag information Queued at: {queued_at} hash info: {dag_version}")


# [END howto_listen_dagrun_running_task]
5 changes: 3 additions & 2 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
data_interval=data_interval,
external_trigger=False,
session=session,
dag_version_id=latest_dag_version.id,
dag_version_id=latest_dag_version.id if latest_dag_version else None,
creating_job_id=self.job.id,
triggered_by=DagRunTriggeredByType.TIMETABLE,
)
Expand Down Expand Up @@ -1758,7 +1758,8 @@ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) ->
Return True if we determine that DAG still exists.
"""
latest_dag_version = DagVersion.get_latest_version(dag_run.dag_id, session=session)
if dag_run.dag_version == latest_dag_version:
latest_dag_version_id = latest_dag_version.id if latest_dag_version else None
if dag_run.dag_version_id == latest_dag_version_id:
self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
return True

Expand Down
2 changes: 1 addition & 1 deletion airflow/models/backfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _create_backfill(
from airflow.models.serialized_dag import SerializedDagModel

with create_session() as session:
serdag = session.get(SerializedDagModel, dag_id)
serdag = session.scalar(SerializedDagModel.latest_item_select_object(dag_id))
if not serdag:
raise NotFound(f"Could not find dag {dag_id}")

Expand Down
111 changes: 55 additions & 56 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,10 @@ def clean_db():
clear_db_runs()
clear_db_backfills()
clear_db_pools()
clear_db_dags()
clear_db_import_errors()
clear_db_jobs()
clear_db_assets()
# DO NOT try to run clear_db_serialized_dags() here - this will break the tests
# DO NOT try to run clear_db_serialized_dags() or clear_db_dags here - this will break the tests
# The tests expect DAGs to be fully loaded here via setUpClass method below

@pytest.fixture(autouse=True)
Expand All @@ -167,9 +166,7 @@ def set_instance_attrs(self, dagbag) -> Generator:
# enqueue!
self.null_exec: MockExecutor | None = MockExecutor()
# Since we don't want to store the code for the DAG defined in this file
with patch("airflow.dag_processing.manager.SerializedDagModel.remove_deleted_dags"), patch(
"airflow.models.dag.DagCode.bulk_sync_to_db"
):
with patch("airflow.models.serialized_dag.SerializedDagModel.remove_deleted_dags"):
yield

self.null_exec = None
Expand Down Expand Up @@ -2860,7 +2857,6 @@ def test_dagrun_root_after_dagrun_unfinished(self, mock_executor):
Noted: the DagRun state could be still in running state during CI.
"""
clear_db_dags()
dag_id = "test_dagrun_states_root_future"
dag = self.dagbag.get_dag(dag_id)
dag.sync_to_db()
Expand Down Expand Up @@ -3301,7 +3297,7 @@ def test_verify_integrity_if_dag_not_changed(self, dag_maker):
assert tis_count == 1

latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
assert dr.dag_hash == latest_dag_version
assert dr.dag_version.serialized_dag.dag_hash == latest_dag_version

session.rollback()
session.close()
Expand Down Expand Up @@ -3335,7 +3331,7 @@ def test_verify_integrity_if_dag_changed(self, dag_maker):
dr = drs[0]

dag_version_1 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
assert dr.dag_hash == dag_version_1
assert dr.dag_version.serialized_dag.dag_hash == dag_version_1
assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_changed": dag}
assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 1

Expand All @@ -3352,7 +3348,7 @@ def test_verify_integrity_if_dag_changed(self, dag_maker):
drs = DagRun.find(dag_id=dag.dag_id, session=session)
assert len(drs) == 1
dr = drs[0]
assert dr.dag_hash == dag_version_2
assert dr.dag_version.serialized_dag.dag_hash == dag_version_2
assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_changed": dag}
assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 2

Expand All @@ -3368,57 +3364,58 @@ def test_verify_integrity_if_dag_changed(self, dag_maker):
assert tis_count == 2

latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
assert dr.dag_hash == latest_dag_version
assert dr.dag_version.serialized_dag.dag_hash == latest_dag_version

session.rollback()
session.close()

def test_verify_integrity_if_dag_disappeared(self, dag_maker, caplog):
# CleanUp
with create_session() as session:
session.query(SerializedDagModel).filter(
SerializedDagModel.dag_id == "test_verify_integrity_if_dag_disappeared"
).delete(synchronize_session=False)

with dag_maker(dag_id="test_verify_integrity_if_dag_disappeared") as dag:
BashOperator(task_id="dummy", bash_command="echo hi")

scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)

session = settings.Session()
orm_dag = dag_maker.dag_model
assert orm_dag is not None

scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)

self.job_runner.processor_agent = mock.MagicMock()
dag = self.job_runner.dagbag.get_dag("test_verify_integrity_if_dag_disappeared", session=session)
self.job_runner._create_dag_runs([orm_dag], session)
dag_id = dag.dag_id
drs = DagRun.find(dag_id=dag_id, session=session)
assert len(drs) == 1
dr = drs[0]

dag_version_1 = SerializedDagModel.get_latest_version_hash(dag_id, session=session)
assert dr.dag_hash == dag_version_1
assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_disappeared": dag}
assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_disappeared").tasks) == 1

SerializedDagModel.remove_dag(dag_id=dag_id)
dag = self.job_runner.dagbag.dags[dag_id]
self.job_runner.dagbag.dags = MagicMock()
self.job_runner.dagbag.dags.get.side_effect = [dag, None]
session.flush()
with caplog.at_level(logging.WARNING):
callback = self.job_runner._schedule_dag_run(dr, session)
assert "The DAG disappeared before verifying integrity" in caplog.text

assert callback is None

session.rollback()
session.close()
# def test_verify_integrity_if_dag_disappeared(self, dag_maker, caplog):
# # CleanUp
# with create_session() as session:
# session.query(SerializedDagModel).filter(
# SerializedDagModel.dag_id == "test_verify_integrity_if_dag_disappeared"
# ).delete(synchronize_session=False)
#
# with dag_maker(dag_id="test_verify_integrity_if_dag_disappeared") as dag:
# BashOperator(task_id="dummy", bash_command="echo hi")
#
# scheduler_job = Job()
# self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
#
# session = settings.Session()
# orm_dag = dag_maker.dag_model
# assert orm_dag is not None
#
# scheduler_job = Job()
# self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
#
# self.job_runner.processor_agent = mock.MagicMock()
# dag = self.job_runner.dagbag.get_dag("test_verify_integrity_if_dag_disappeared", session=session)
# self.job_runner._create_dag_runs([orm_dag], session)
# dag_id = dag.dag_id
# drs = DagRun.find(dag_id=dag_id, session=session)
# assert len(drs) == 1
# dr = drs[0]
#
# dag_version_1 = SerializedDagModel.get_latest_version_hash(dag_id, session=session)
# assert dr.dag_version.serialized_dag.dag_hash == dag_version_1
# assert self.job_runner.dagbag.dags == {"test_verify_integrity_if_dag_disappeared": dag}
# assert len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_disappeared").tasks) == 1
#
# SerializedDagModel.remove_dag(dag_id=dag_id)
# session.query(DagModel).filter(DagModel.dag_id == dag_id).delete()
# dag = self.job_runner.dagbag.dags[dag_id]
# self.job_runner.dagbag.dags = MagicMock()
# self.job_runner.dagbag.dags.get.side_effect = [dag, None]
# session.flush()
# with caplog.at_level(logging.WARNING):
# callback = self.job_runner._schedule_dag_run(dr, session)
# assert "The DAG disappeared before verifying integrity" in caplog.text
#
# assert callback is None
#
# session.rollback()
# session.close()

@pytest.mark.need_serialized_dag
def test_retry_still_in_executor(self, dag_maker):
Expand Down Expand Up @@ -3987,6 +3984,7 @@ def test_create_dag_runs_assets(self, session, dag_maker):
]
)
session.flush()
session.commit()

scheduler_job = Job(executor=self.null_exec)
self.job_runner = SchedulerJobRunner(job=scheduler_job)
Expand Down Expand Up @@ -5749,6 +5747,7 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce
assert expected_failure_callback_requests[0] == callback_requests[0]

def test_cleanup_stale_dags(self):
clear_db_dags()
dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
with create_session() as session:
dag = dagbag.get_dag("test_example_bash_operator")
Expand Down
5 changes: 5 additions & 0 deletions tests/models/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ def test_get_dag_with_dag_serialization(self):

with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 0)), tick=False):
example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator")
example_bash_op_dag.sync_to_db()
SerializedDagModel.write_dag(dag=example_bash_op_dag)

dag_bag = DagBag(read_dags_from_db=True)
Expand All @@ -836,6 +837,7 @@ def test_get_dag_with_dag_serialization(self):
# Make a change in the DAG and write Serialized DAG to the DB
with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 6)), tick=False):
example_bash_op_dag.tags.add("new_tag")
example_bash_op_dag.sync_to_db()
SerializedDagModel.write_dag(dag=example_bash_op_dag)

# Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
Expand All @@ -860,6 +862,7 @@ def test_get_dag_refresh_race_condition(self):
# serialize the initial version of the DAG
with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 0)), tick=False):
example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator")
example_bash_op_dag.sync_to_db()
SerializedDagModel.write_dag(dag=example_bash_op_dag)

# deserialize the DAG
Expand All @@ -885,6 +888,7 @@ def test_get_dag_refresh_race_condition(self):
# long before the transaction is committed
with time_machine.travel((tz.datetime(2020, 1, 5, 1, 0, 0)), tick=False):
example_bash_op_dag.tags.add("new_tag")
example_bash_op_dag.sync_to_db()
SerializedDagModel.write_dag(dag=example_bash_op_dag)

# Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
Expand All @@ -905,6 +909,7 @@ def test_collect_dags_from_db(self):

example_dags = dagbag.dags
for dag in example_dags.values():
dag.sync_to_db()
SerializedDagModel.write_dag(dag)

new_dagbag = DagBag(read_dags_from_db=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_dagcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# To move it to a shared module.
from airflow.utils.file import open_maybe_zipped
from airflow.utils.session import create_session
from tests_common.test_utils.db import clear_db_dag_code
from tests_common.test_utils.db import clear_db_dag_code, clear_db_dags

pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]

Expand Down
13 changes: 13 additions & 0 deletions tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,12 +811,25 @@ def __exit__(self, type, value, traceback):
self.dag_model = self.session.get(DagModel, dag.dag_id)

if self.want_serialized:
from airflow.models.dag_version import DagVersion
from airflow.models.dagcode import DagCode

self.serialized_model = SerializedDagModel(
dag, processor_subdir=self.dag_model.processor_subdir
)
self.session.merge(self.serialized_model)
serialized_dag = self._serialized_dag()
self._bag_dag_compat(serialized_dag)
dag_code = DagCode(dag.fileloc, "Source")
self.session.merge(dag_code)
dagv = DagVersion.write_dag(
dag_id=dag.dag_id,
dag_code=dag_code,
serialized_dag=self.serialized_model,
session=self.session,
version_name=dag.version_name,
)
self.session.merge(dagv)
self.session.flush()
else:
self._bag_dag_compat(self.dag)
Expand Down

0 comments on commit ad4f57c

Please sign in to comment.