Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ repos:
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
^airflow-core/tests/unit/models/test_renderedtifields.py$|
^task_sdk.*\.py$
pass_filenames: true
- id: update-supported-versions
Expand Down
32 changes: 17 additions & 15 deletions airflow-core/tests/unit/models/test_renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ def test_delete_old_records(
session.add_all(rtif_list)
session.flush()

result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
result = session.scalars(
select(RTIF).where(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id)
).all()

for rtif in rtif_list:
assert rtif in result
Expand All @@ -270,7 +272,9 @@ def test_delete_old_records(

with assert_queries_count(expected_query_count):
RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
result = session.scalars(
select(RTIF).where(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id)
).all()
assert remaining_rtifs == len(result)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -302,14 +306,16 @@ def test_delete_old_records_mapped(
session.add(RTIF(ti))
session.flush()

result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
result = session.scalars(select(RTIF).where(RTIF.dag_id == dag.dag_id)).all()
assert len(result) == num_runs * 2

with assert_queries_count(expected_query_count):
RTIF.delete_old_records(
task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
)
result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
result = session.scalars(
select(RTIF).where(RTIF.dag_id == dag.dag_id, RTIF.task_id == mapped.task_id)
).all()
rtif_num_runs = Counter(rtif.run_id for rtif in result)
assert len(rtif_num_runs) == remaining_rtifs
# Check that we have _all_ the data for each row
Expand All @@ -322,7 +328,7 @@ def test_write(self, dag_maker):
Variable.set(key="test_key", value="test_val")

session = settings.Session()
result = session.query(RTIF).all()
result = session.scalars(select(RTIF)).all()
assert result == []

with dag_maker("test_write"):
Expand All @@ -334,15 +340,13 @@ def test_write(self, dag_maker):

rtif = RTIF(ti)
rtif.write()
result = (
session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
.filter(
result = session.execute(
select(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).where(
RTIF.dag_id == rtif.dag_id,
RTIF.task_id == rtif.task_id,
RTIF.run_id == rtif.run_id,
)
.first()
)
).first()
assert result == ("test_write", "test", {"bash_command": "echo test_val", "env": None, "cwd": None})

# Test that overwrite saves new values to the DB
Expand All @@ -357,15 +361,13 @@ def test_write(self, dag_maker):
rtif_updated = RTIF(ti)
rtif_updated.write()

result_updated = (
session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
.filter(
result_updated = session.execute(
select(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).where(
RTIF.dag_id == rtif_updated.dag_id,
RTIF.task_id == rtif_updated.task_id,
RTIF.run_id == rtif_updated.run_id,
)
.first()
)
).first()
assert result_updated == (
"test_write",
"test",
Expand Down