Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions src/prefect/server/events/models/composite_trigger_child_firing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,46 @@

from prefect.server.database import PrefectDBInterface, db_injector
from prefect.server.events.schemas.automations import CompositeTrigger, Firing
from prefect.server.utilities.database import get_dialect
from prefect.types._datetime import DateTime, now

if TYPE_CHECKING:
from prefect.server.database.orm_models import ORMCompositeTriggerChildFiring


async def acquire_composite_trigger_lock(
    session: AsyncSession,
    trigger: CompositeTrigger,
) -> None:
    """
    Acquire a transaction-scoped advisory lock for the given composite trigger.

    This serializes concurrent child trigger evaluations for the same compound
    trigger, preventing a race condition where multiple transactions each see
    only their own child firing and neither fires the parent.

    The lock is automatically released when the transaction commits or rolls back.

    Args:
        session: the session whose current transaction will hold the lock
        trigger: the composite trigger whose evaluations should be serialized

    Returns:
        None. On dialects other than PostgreSQL, or when the session has no
        bind, this is a no-op.
    """
    bind = session.get_bind()
    if bind is None:
        return

    # Get the engine from either an Engine or Connection
    engine: sa.Engine = bind if isinstance(bind, sa.Engine) else bind.engine  # type: ignore[union-attr]
    dialect = get_dialect(engine)

    if dialect.name == "postgresql":
        # pg_advisory_xact_lock takes a signed bigint, so reduce the UUID's
        # 128-bit integer value into the range [0, 2**63).  The key MUST be
        # identical in every process that evaluates this trigger: the builtin
        # hash() of a str is salted per interpreter (PYTHONHASHSEED), so it
        # would hand each server process a different key and the lock would
        # not serialize anything across processes.  A collision between two
        # triggers is extremely unlikely and benign (they would merely
        # serialize against each other).
        lock_key = trigger.id.int % (2**63)
        await session.execute(
            sa.text("SELECT pg_advisory_xact_lock(:key)"), {"key": lock_key}
        )
    # SQLite doesn't support advisory locks, but SQLite also serializes writes
    # at the database level, so the race condition is less likely to occur


@db_injector
async def upsert_child_firing(
db: PrefectDBInterface,
Expand Down Expand Up @@ -102,11 +136,22 @@ async def clear_child_firings(
session: AsyncSession,
trigger: CompositeTrigger,
firing_ids: Sequence[UUID],
) -> None:
await session.execute(
sa.delete(db.CompositeTriggerChildFiring).filter(
) -> set[UUID]:
"""
Delete the specified child firings and return the IDs that were actually deleted.

Returns the set of child_firing_ids that were successfully deleted. Callers can
compare this to the expected firing_ids to detect races and avoid double-firing
composite triggers.
"""
result = await session.execute(
sa.delete(db.CompositeTriggerChildFiring)
.filter(
db.CompositeTriggerChildFiring.automation_id == trigger.automation.id,
db.CompositeTriggerChildFiring.parent_trigger_id == trigger.id,
db.CompositeTriggerChildFiring.child_firing_id.in_(firing_ids),
)
.returning(db.CompositeTriggerChildFiring.child_trigger_id)
)

return set(result.scalars().all())
34 changes: 29 additions & 5 deletions src/prefect/server/events/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
read_automation,
)
from prefect.server.events.models.composite_trigger_child_firing import (
acquire_composite_trigger_lock,
clear_child_firings,
clear_old_child_firings,
get_child_firings,
Expand Down Expand Up @@ -65,12 +66,11 @@
from prefect.settings.context import get_current_settings

if TYPE_CHECKING:
import logging

from prefect.server.database.orm_models import ORMAutomationBucket

import logging

logger: "logging.Logger" = get_logger(__name__)
logger = logging.getLogger(__name__)

AutomationID: TypeAlias = UUID
TriggerID: TypeAlias = UUID
Expand Down Expand Up @@ -346,6 +346,11 @@ async def evaluate_composite_trigger(session: AsyncSession, firing: Firing) -> N
)
return

# Acquire an advisory lock to serialize concurrent evaluations for this
# compound trigger. This prevents a race condition where multiple child
# triggers fire concurrently and neither transaction sees both firings.
await acquire_composite_trigger_lock(session, trigger)

# If we're only looking within a certain time horizon, remove any older firings that
# should no longer be considered as satisfying this trigger
if trigger.within is not None:
Expand Down Expand Up @@ -382,8 +387,27 @@ async def evaluate_composite_trigger(session: AsyncSession, firing: Firing) -> N
},
)

# clear by firing id
await clear_child_firings(session, trigger, firing_ids=list(firing_ids))
# Clear by firing id, and only proceed if we won the race to claim them.
# This prevents double-firing when multiple workers evaluate concurrently.
deleted_ids = await clear_child_firings(
session, trigger, firing_ids=list(firing_ids)
)

if len(deleted_ids) != len(firing_ids):
logger.debug(
"Composite trigger %s skipped fire; expected to delete %s firings, "
"actually deleted %s (another worker likely claimed them)",
trigger.id,
len(firing_ids),
len(deleted_ids),
extra={
"automation": automation.id,
"trigger": trigger.id,
"expected_firing_ids": sorted(str(f) for f in firing_ids),
"deleted_firing_ids": sorted(str(f) for f in deleted_ids),
},
)
return

await fire(
session,
Expand Down
121 changes: 121 additions & 0 deletions tests/events/server/triggers/test_composite_triggers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import datetime
from datetime import timedelta
from typing import List
Expand Down Expand Up @@ -1624,3 +1625,123 @@ async def test_sequence_trigger_identical_event_triggers_only_one_fired_does_not
await triggers.reactive_evaluation(ingredients_buy)

act.assert_not_called()


class TestCompoundTriggerConcurrency:
    """Regression coverage for child firings of a compound trigger that race."""

    @staticmethod
    def _racing_child_events(start_of_test: DateTime) -> List[ReceivedEvent]:
        """Build one `event.A` and one `event.B`, a microsecond apart, on the same resource."""
        return [
            ReceivedEvent(
                occurred=start_of_test + timedelta(microseconds=offset),
                event=event_name,
                resource={"prefect.resource.id": "test.resource"},
                id=uuid4(),
            )
            for offset, event_name in enumerate(("event.A", "event.B"), start=1)
        ]

    @pytest.fixture
    async def compound_automation_concurrent(
        self,
        automations_session: AsyncSession,
        cleared_buckets: None,
        cleared_automations: None,
    ) -> Automation:
        """An automation whose compound trigger needs both child triggers to fire."""
        # One reactive, threshold-1 child per expected event name.
        child_triggers = [
            EventTrigger(
                expect={event_name},
                match={"prefect.resource.id": "*"},
                posture=Posture.Reactive,
                threshold=1,
            )
            for event_name in ("event.A", "event.B")
        ]
        automation = Automation(
            name="Compound Automation Concurrency Test",
            trigger=CompoundTrigger(
                require="all",
                within=timedelta(minutes=5),
                triggers=child_triggers,
            ),
            actions=[actions.DoNothing()],
        )

        stored = await automations.create_automation(
            session=automations_session, automation=automation
        )
        # Mirror the persisted timestamps onto the in-memory automation.
        automation.created = stored.created
        automation.updated = stored.updated
        triggers.load_automation(stored)
        await automations_session.commit()

        return automation

    async def test_compound_trigger_does_not_double_fire_when_children_race(
        self,
        act: mock.AsyncMock,
        compound_automation_concurrent: Automation,
        start_of_test: DateTime,
    ):
        """
        Two child events evaluated concurrently must fire the parent exactly once.

        Guards against the double-fire regression: the DELETE ... RETURNING
        claim is meant to let only one evaluation go on to fire the compound
        trigger.
        """
        racing_events = self._racing_child_events(start_of_test)

        # Evaluate both child events at the same time.
        await asyncio.gather(
            *(triggers.reactive_evaluation(event) for event in racing_events)
        )

        act.assert_called_once()

        parent_firing: Firing = act.call_args.args[0]
        assert isinstance(parent_firing.trigger, CompoundTrigger)
        assert parent_firing.trigger.id == compound_automation_concurrent.trigger.id

    async def test_concurrent_child_firings_still_triggers_parent(
        self,
        act: mock.AsyncMock,
        compound_automation_concurrent: Automation,
        start_of_test: DateTime,
    ):
        """
        Near-simultaneous child events must still fire the compound trigger.

        Confirms that the race-condition fix does not suppress a legitimate
        parent firing.
        """
        racing_events = self._racing_child_events(start_of_test)

        # Run both evaluations together to simulate the race.
        await asyncio.gather(
            *(triggers.reactive_evaluation(event) for event in racing_events)
        )

        act.assert_called_once()

        parent_firing: Firing = act.call_args.args[0]
        assert parent_firing.trigger.id == compound_automation_concurrent.trigger.id
Loading