diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 26f573be1e3e2b..78bb91113de8c1 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -187,12 +187,12 @@ def create_task_instances( ) session.add(dr) ti = TaskInstance(task=tasks[i], **self.ti_init) + session.add(ti) ti.dag_run = dr ti.note = "placeholder-note" for key, value in self.ti_extras.items(): setattr(ti, key, value) - session.add(ti) tis.append(ti) session.commit() diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py index 2e774bec6d487e..c327e182fa15a7 100644 --- a/tests/api_connexion/schemas/test_task_instance_schema.py +++ b/tests/api_connexion/schemas/test_task_instance_schema.py @@ -65,6 +65,7 @@ def set_attrs(self, session, dag_maker): def test_task_instance_schema_without_sla_and_rendered(self, session): ti = TI(task=self.task, **self.default_ti_init) + session.add(ti) for key, value in self.default_ti_extras.items(): setattr(ti, key, value) serialized_ti = task_instance_schema.dump((ti, None, None)) @@ -109,6 +110,7 @@ def test_task_instance_schema_with_sla_and_rendered(self, session): session.add(sla_miss) session.flush() ti = TI(task=self.task, **self.default_ti_init) + session.add(ti) for key, value in self.default_ti_extras.items(): setattr(ti, key, value) self.task.template_fields = ["partitions"] diff --git a/tests/api_experimental/common/test_mark_tasks.py b/tests/api_experimental/common/test_mark_tasks.py index 9b28136bba2797..83b966c2095f91 100644 --- a/tests/api_experimental/common/test_mark_tasks.py +++ b/tests/api_experimental/common/test_mark_tasks.py @@ -21,7 +21,7 @@ from typing import Callable import pytest -from sqlalchemy.orm import eagerload +from sqlalchemy.orm import joinedload from airflow import models from airflow.api.common.mark_tasks import ( @@ -134,7 +134,7 @@ def snapshot_state(dag, execution_dates): return ( session.query(TI) .join(TI.dag_run) - .options(eagerload(TI.dag_run)) + .options(joinedload(TI.dag_run)) .filter(TI.dag_id == dag.dag_id, DR.execution_date.in_(execution_dates)) .all() ) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index c39265f5759ac9..108ee9700383c1 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -1909,8 +1909,8 @@ def consumer(value): ti.map_index = 0 for map_index in range(1, 3): ti = TI(consumer_op, run_id=dr.run_id, map_index=map_index) - ti.dag_run = dr session.add(ti) + ti.dag_run = dr session.flush() executor = MockExecutor() diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 8a10c28f1c4a5f..fe1c2d58a29e30 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -2115,8 +2115,8 @@ def do_something_else(i): task = ti.task for map_index in range(1, 5): ti = TI(task, run_id=dr.run_id, map_index=map_index) - ti.dag_run = dr session.add(ti) + ti.dag_run = dr session.flush() tis = dr.get_task_instances() for ti in tis: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 9f6447b3ae3d92..0635a5f053c169 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1484,8 +1484,8 @@ def do_something_else(i): ti.map_index = 0 for map_index in range(1, 5): ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) - ti.dag_run = dr session.add(ti) + ti.dag_run = dr session.flush() downstream = ti.task ti = dr.get_task_instance(task_id="do_something_else", map_index=3, session=session) diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py index 222dbdbbb49a57..a0d5857f7285ca 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py @@ -66,11 +66,11 @@ def test_serialize(self): username="test", password="test", email=TEST_EMAIL, - roles=[self.role], created_on=timezone.parse(DEFAULT_TIME), changed_on=timezone.parse(DEFAULT_TIME), ) self.session.add(user_model) + user_model.roles = [self.role] self.session.commit() user = self.session.query(User).filter(User.email == TEST_EMAIL).first() deserialized_user = user_collection_item_schema.dump(user) diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py index 8ba78a1d31452b..549e703485cd7c 100644 --- a/tests/serialization/test_pydantic_models.py +++ b/tests/serialization/test_pydantic_models.py @@ -159,8 +159,12 @@ def test_serializing_pydantic_dataset_event(session, create_task_instance, creat ds2_event_1 = DatasetEvent(dataset_id=2) ds2_event_2 = DatasetEvent(dataset_id=2) - DagScheduleDatasetReference(dag_id=dag.dag_id, dataset=ds1) - TaskOutletDatasetReference(task_id=task1.task_id, dag_id=dag.dag_id, dataset=ds1) + dag_ds_ref = DagScheduleDatasetReference(dag_id=dag.dag_id) + session.add(dag_ds_ref) + dag_ds_ref.dataset = ds1 + task_ds_ref = TaskOutletDatasetReference(task_id=task1.task_id, dag_id=dag.dag_id) + session.add(task_ds_ref) + task_ds_ref.dataset = ds1 dr.consumed_dataset_events.append(ds1_event) dr.consumed_dataset_events.append(ds2_event_1) diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 2bd652567a07e4..73679d63912539 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -137,8 +137,8 @@ def _expand_tasks(task_instance: str, upstream: str) -> BaseOperator | None: ti.map_index = 0 for map_index in range(1, 5): ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) - ti.dag_run = dr session.add(ti) + ti.dag_run = dr session.flush() tis = dr.get_task_instances(session=session) for ti in tis: diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 2bfc574b640362..077270480042d9 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -493,6 +493,7 @@ def test_set_context_trigger(self, create_dummy_dag, dag_maker, is_a_trigger, se job = Job() t = Trigger("", {}) t.triggerer_job = job + session.add(t) ti.triggerer = t t.task_instance = ti h = FileTaskHandler(base_log_folder=os.fspath(tmp_path))