Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ repos:
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
^airflow-core/tests/unit/models/test_trigger.py$|
^dev/airflow_perf/scheduler_dag_execution_timing.py$|
^providers/openlineage/.*\.py$|
^task_sdk.*\.py$
Expand Down
69 changes: 38 additions & 31 deletions airflow-core/tests/unit/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pytest
import pytz
from cryptography.fernet import Fernet
from sqlalchemy import delete, func, select

from airflow._shared.timezones import timezone
from airflow.jobs.job import Job
Expand Down Expand Up @@ -61,21 +62,21 @@ def session():

@pytest.fixture(autouse=True)
def clear_db(session):
session.query(TaskInstance).delete()
session.query(AssetWatcherModel).delete()
session.query(Callback).delete()
session.query(Trigger).delete()
session.query(AssetModel).delete()
session.query(AssetEvent).delete()
session.query(Job).delete()
session.execute(delete(TaskInstance))
session.execute(delete(AssetWatcherModel))
session.execute(delete(Callback))
session.execute(delete(Trigger))
session.execute(delete(AssetModel))
session.execute(delete(AssetEvent))
session.execute(delete(Job))
yield session
session.query(TaskInstance).delete()
session.query(AssetWatcherModel).delete()
session.query(Callback).delete()
session.query(Trigger).delete()
session.query(AssetModel).delete()
session.query(AssetEvent).delete()
session.query(Job).delete()
session.execute(delete(TaskInstance))
session.execute(delete(AssetWatcherModel))
session.execute(delete(Callback))
session.execute(delete(Trigger))
session.execute(delete(AssetModel))
session.execute(delete(AssetEvent))
session.execute(delete(Job))
session.commit()


Expand Down Expand Up @@ -121,7 +122,7 @@ def test_clean_unused(session, create_task_instance):
session.add(trigger5)
session.add(trigger6)
session.commit()
assert session.query(Trigger).count() == 6
assert session.scalar(select(func.count()).select_from(Trigger)) == 6
# Tie one to a fake TaskInstance that is not deferred, and one to one that is
task_instance = create_task_instance(
session=session, task_id="fake", state=State.DEFERRED, logical_date=timezone.utcnow()
Expand Down Expand Up @@ -150,7 +151,7 @@ def test_clean_unused(session, create_task_instance):
asset.add_trigger(trigger5, "test_asset_watcher2")
session.add(asset)
session.commit()
assert session.query(AssetModel).count() == 1
assert session.scalar(select(func.count()).select_from(AssetModel)) == 1

# Create callback with trigger
callback = TriggererCallback(
Expand All @@ -162,7 +163,7 @@ def test_clean_unused(session, create_task_instance):

# Run clear operation
Trigger.clean_unused()
results = session.query(Trigger).all()
results = session.scalars(select(Trigger)).all()
assert len(results) == 4
assert {result.id for result in results} == {trigger1.id, trigger4.id, trigger5.id, trigger6.id}

Expand Down Expand Up @@ -196,7 +197,10 @@ def test_submit_event(mock_callback_handle_event, session, create_task_instance)
session.commit()

# Check that the asset has 0 event prior to sending an event to the trigger
assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 0
assert (
session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asset.id))
== 0
)

# Create event
payload = "payload"
Expand All @@ -210,8 +214,11 @@ def test_submit_event(mock_callback_handle_event, session, create_task_instance)
assert task_instance.state == State.SCHEDULED
assert task_instance.next_kwargs == {"event": payload, "cheesecake": True}
# Check that the asset has received an event
assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 1
asset_event = session.query(AssetEvent).filter_by(asset_id=asset.id).first()
assert (
session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asset.id))
== 1
)
asset_event = session.scalar(select(AssetEvent).where(AssetEvent.asset_id == asset.id))
assert asset_event.extra == {"from_trigger": True, "payload": payload}

# Check that the callback's handle_event was called
Expand All @@ -233,7 +240,7 @@ def test_submit_failure(session, create_task_instance):
# Call submit_event
Trigger.submit_failure(trigger.id, session=session)
# Check that the task instance is now scheduled to fail
updated_task_instance = session.query(TaskInstance).one()
updated_task_instance = session.scalar(select(TaskInstance))
assert updated_task_instance.state == State.SCHEDULED
assert updated_task_instance.next_method == "__fail__"

Expand Down Expand Up @@ -272,7 +279,7 @@ def get_xcoms(ti):

# now for the real test
# first check initial state
ti: TaskInstance = session.query(TaskInstance).one()
ti: TaskInstance = session.scalar(select(TaskInstance))
assert ti.state == "deferred"
assert get_xcoms(ti) == []

Expand All @@ -285,7 +292,7 @@ def get_xcoms(ti):
# commit changes made by submit event and expire all cache to read from db.
session.flush()
# Check that the task instance is now correct
ti = session.query(TaskInstance).one()
ti = session.scalar(select(TaskInstance))
assert ti.state == expected
assert ti.next_kwargs is None
assert ti.end_date == now
Expand Down Expand Up @@ -370,26 +377,26 @@ def test_assign_unassigned(session, create_task_instance):
session.add(ti_trigger_unassigned_to_triggerer)
assert trigger_unassigned_to_triggerer.triggerer_id is None
session.commit()
assert session.query(Trigger).count() == 4
assert session.scalar(select(func.count()).select_from(Trigger)) == 4
Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30)
session.expire_all()
# Check that trigger on killed triggerer and unassigned trigger are assigned to new triggerer
assert (
session.query(Trigger).filter(Trigger.id == trigger_on_killed_triggerer.id).one().triggerer_id
session.scalar(select(Trigger).where(Trigger.id == trigger_on_killed_triggerer.id)).triggerer_id
== new_triggerer.id
)
assert (
session.query(Trigger).filter(Trigger.id == trigger_unassigned_to_triggerer.id).one().triggerer_id
session.scalar(select(Trigger).where(Trigger.id == trigger_unassigned_to_triggerer.id)).triggerer_id
== new_triggerer.id
)
# Check that trigger on healthy triggerer still assigned to existing triggerer
assert (
session.query(Trigger).filter(Trigger.id == trigger_on_healthy_triggerer.id).one().triggerer_id
session.scalar(select(Trigger).where(Trigger.id == trigger_on_healthy_triggerer.id)).triggerer_id
== healthy_triggerer.id
)
# Check that trigger on unhealthy triggerer is assigned to new triggerer
assert (
session.query(Trigger).filter(Trigger.id == trigger_on_unhealthy_triggerer.id).one().triggerer_id
session.scalar(select(Trigger).where(Trigger.id == trigger_on_unhealthy_triggerer.id)).triggerer_id
== new_triggerer.id
)

Expand Down Expand Up @@ -453,7 +460,7 @@ def test_get_sorted_triggers_same_priority_weight(session, create_task_instance)
)
session.add(trigger_callback)
session.commit()
assert session.query(Trigger).count() == 5
assert session.scalar(select(func.count()).select_from(Trigger)) == 5
# Create assets
asset = AssetModel("test")
asset.add_trigger(trigger_asset, "test_asset_watcher")
Expand Down Expand Up @@ -534,7 +541,7 @@ def test_get_sorted_triggers_different_priority_weights(session, create_task_ins
session.add(TI_new)

session.commit()
assert session.query(Trigger).count() == 5
assert session.scalar(select(func.count()).select_from(Trigger)) == 5

trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)

Expand Down Expand Up @@ -605,7 +612,7 @@ def test_get_sorted_triggers_dont_starve_for_ha(session, create_task_instance):
asset_triggers.append(trigger)

session.commit()
assert session.query(Trigger).count() == 60
assert session.scalar(select(func.count()).select_from(Trigger)) == 60

# Mock max_trigger_to_select_per_loop to 5 for testing
with patch.object(Trigger, "max_trigger_to_select_per_loop", 5):
Expand Down