28 changes: 8 additions & 20 deletions airflow/dag_processing/collection.py
@@ -27,10 +27,9 @@

from __future__ import annotations

import json
import logging
import traceback
from typing import TYPE_CHECKING, Any, NamedTuple, cast
from typing import TYPE_CHECKING, NamedTuple, cast

from sqlalchemy import and_, delete, exists, func, insert, select, tuple_
from sqlalchemy.exc import OperationalError
@@ -53,7 +52,8 @@
from airflow.models.errors import ParseImportError
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef
from airflow.serialization.serialized_objects import BaseSerialization, SerializedAssetWatcher
from airflow.serialization.serialized_objects import SerializedAssetWatcher
from airflow.triggers.base import BaseEventTrigger
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
@@ -758,7 +758,7 @@ def add_asset_trigger_references(
else []
)
trigger_hash_to_trigger_dict: dict[int, dict] = {
self._get_trigger_hash(
BaseEventTrigger.hash(
watcher.trigger["classpath"], watcher.trigger["kwargs"]
): watcher.trigger
for watcher in asset_watchers
@@ -768,7 +768,7 @@

asset_model = assets[name_uri]
trigger_hash_from_asset_model: set[int] = {
self._get_trigger_hash(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
BaseEventTrigger.hash(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
}

# Optimization: no diff between the DB and DAG definitions, no update needed
@@ -796,7 +796,7 @@ def add_asset_trigger_references(
for trigger_hash in trigger_hashes
}
orm_triggers: dict[int, Trigger] = {
self._get_trigger_hash(trigger.classpath, trigger.kwargs): trigger
BaseEventTrigger.hash(trigger.classpath, trigger.kwargs): trigger
for trigger in session.scalars(
select(Trigger).where(
tuple_(Trigger.classpath, Trigger.encrypted_kwargs).in_(all_trigger_keys)
@@ -817,7 +817,7 @@ def add_asset_trigger_references(
]
session.add_all(new_trigger_models)
orm_triggers.update(
(self._get_trigger_hash(trigger.classpath, trigger.kwargs), trigger)
(BaseEventTrigger.hash(trigger.classpath, trigger.kwargs), trigger)
for trigger in new_trigger_models
)

@@ -835,7 +835,7 @@ def add_asset_trigger_references(
asset_model.triggers = [
trigger
for trigger in asset_model.triggers
if self._get_trigger_hash(trigger.classpath, trigger.kwargs) not in trigger_hashes
if BaseEventTrigger.hash(trigger.classpath, trigger.kwargs) not in trigger_hashes
]

# Remove references from assets no longer used
@@ -845,15 +845,3 @@ def add_asset_trigger_references(
for asset_model in orphan_assets:
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []

@staticmethod
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
"""
Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.
We do not want to move this logic in a `__hash__` method in `BaseTrigger` because we do not want to
make the triggers hashable. The reason being, when the triggerer retrieve the list of triggers, we do
not want it dedupe them. When used to defer tasks, 2 triggers can have the same classpath and kwargs.
This is not true for event driven scheduling.
"""
return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
26 changes: 6 additions & 20 deletions airflow/example_dags/example_asset_with_watchers.py
@@ -20,30 +20,17 @@

from __future__ import annotations

import os

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.standard.triggers.file import FileTrigger
from airflow.providers.standard.triggers.file import FileDeleteTrigger
from airflow.sdk import Asset, AssetWatcher

file_path = "/tmp/test"

with DAG(
dag_id="example_create_file",
catchup=False,
):

@task
def create_file():
with open(file_path, "w") as file:
file.write("This is an example file.\n")

chain(create_file())
trigger = FileDeleteTrigger(filepath=file_path)
asset = Asset("example_asset", watchers=[AssetWatcher(name="test_asset_watcher", trigger=trigger)])

trigger = FileTrigger(filepath=file_path, poke_interval=10)
asset = Asset("example_asset", watchers=[AssetWatcher(name="test_file_watcher", trigger=trigger)])

with DAG(
dag_id="example_asset_with_watchers",
@@ -52,8 +39,7 @@ def create_file():
):

@task
def delete_file():
if os.path.exists(file_path):
os.remove(file_path) # Delete the file
def test_task():
print("Hello world")

chain(delete_file())
chain(test_task())
3 changes: 2 additions & 1 deletion airflow/serialization/serialized_objects.py
@@ -102,6 +102,7 @@
from airflow.sdk.types import Operator
from airflow.serialization.json_schema import Validator
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
from airflow.triggers.base import BaseEventTrigger

HAS_KUBERNETES: bool
try:
@@ -259,7 +260,7 @@ def _encode_watcher(watcher: AssetWatcher):
"trigger": _encode_trigger(watcher.trigger),
}

def _encode_trigger(trigger: BaseTrigger | dict):
def _encode_trigger(trigger: BaseEventTrigger | dict):
if isinstance(trigger, dict):
return trigger
classpath, kwargs = trigger.serialize()
22 changes: 22 additions & 0 deletions airflow/triggers/base.py
@@ -17,6 +17,7 @@
from __future__ import annotations

import abc
import json
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
@@ -126,6 +127,27 @@ def __repr__(self) -> str:
return self.repr(classpath, kwargs)


class BaseEventTrigger(BaseTrigger):
"""
Base class for triggers used to schedule DAGs based on external events.

``BaseEventTrigger`` is a subclass of ``BaseTrigger`` designed to identify triggers compatible with
event-driven scheduling.
"""

@staticmethod
def hash(classpath: str, kwargs: dict[str, Any]) -> int:
"""
Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.

We do not want this logic in ``BaseTrigger`` because, when used to defer tasks, two triggers
can legitimately share the same classpath and kwargs. This is not the case for event-driven scheduling.
"""
from airflow.serialization.serialized_objects import BaseSerialization

return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))


class TriggerEvent:
"""
Something that a trigger can fire when its conditions are met.
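For context on how this ``hash`` helper is consumed, here is a rough sketch (assuming an environment where Airflow's serialization machinery is importable; the classpath and kwargs values are illustrative): two event triggers declared with the same classpath and kwargs resolve to the same hash, which is how ``add_asset_trigger_references`` in ``collection.py`` matches DAG-defined watchers against rows already in the trigger table.

```python
from airflow.triggers.base import BaseEventTrigger

# Two asset watchers pointing at the same file should map to a single trigger row.
classpath = "airflow.providers.standard.triggers.file.FileDeleteTrigger"
kwargs = {"filepath": "/tmp/test", "poke_interval": 5.0}

h1 = BaseEventTrigger.hash(classpath, kwargs)
h2 = BaseEventTrigger.hash(classpath, dict(kwargs))  # equal kwargs, distinct dict object

assert h1 == h2  # identical classpath + kwargs -> identical hash, so the existing trigger is reused
```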
60 changes: 57 additions & 3 deletions providers/standard/src/airflow/providers/standard/triggers/file.py
@@ -19,11 +19,20 @@
import asyncio
import datetime
import os
import typing
from collections.abc import AsyncIterator
from glob import glob
from typing import Any

from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.triggers.base import BaseEventTrigger, BaseTrigger, TriggerEvent
else:
from airflow.triggers.base import ( # type: ignore
BaseTrigger,
BaseTrigger as BaseEventTrigger,
TriggerEvent,
)


class FileTrigger(BaseTrigger):
@@ -60,7 +69,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
},
)

async def run(self) -> typing.AsyncIterator[TriggerEvent]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Loop until the relevant files are found."""
while True:
for path in glob(self.filepath, recursive=self.recursive):
@@ -75,3 +84,48 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
yield TriggerEvent(True)
return
await asyncio.sleep(self.poke_interval)


class FileDeleteTrigger(BaseEventTrigger):
"""
A trigger that fires exactly once after it finds the requested file and then deletes the file.

The difference between ``FileTrigger`` and ``FileDeleteTrigger`` is that ``FileDeleteTrigger`` watches a
single, specific file (no glob patterns) and deletes it once found.

:param filepath: Path to the file to watch for and delete.
:param poke_interval: Time, in seconds, to wait between each check.
"""

def __init__(
self,
filepath: str,
poke_interval: float = 5.0,
**kwargs,
):
super().__init__()
self.filepath = filepath
self.poke_interval = poke_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize FileDeleteTrigger arguments and classpath."""
return (
"airflow.providers.standard.triggers.file.FileDeleteTrigger",
{
"filepath": self.filepath,
"poke_interval": self.poke_interval,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Loop until the relevant file is found."""
while True:
if os.path.isfile(self.filepath):
mod_time_f = os.path.getmtime(self.filepath)
mod_time = datetime.datetime.fromtimestamp(mod_time_f).strftime("%Y%m%d%H%M%S")
self.log.info("Found file %s last modified: %s", self.filepath, mod_time)
os.remove(self.filepath)
self.log.info("File %s has been deleted", self.filepath)
yield TriggerEvent(True)
return
await asyncio.sleep(self.poke_interval)
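A minimal sketch of how ``FileDeleteTrigger`` is meant to be wired into event-driven scheduling, mirroring the updated example DAG above (the DAG id and file path are illustrative, and ``schedule=[asset]`` is the usual way to schedule a DAG on an asset rather than something introduced by this PR):

```python
from airflow.models.dag import DAG
from airflow.providers.standard.triggers.file import FileDeleteTrigger
from airflow.sdk import Asset, AssetWatcher

# The trigger fires once the file appears and deletes it; the asset event then schedules the DAG.
trigger = FileDeleteTrigger(filepath="/tmp/example")
asset = Asset("example_asset", watchers=[AssetWatcher(name="example_watcher", trigger=trigger)])

with DAG(dag_id="example_event_driven_dag", schedule=[asset], catchup=False):
    ...
```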
@@ -20,7 +20,8 @@

import pytest

from airflow.providers.standard.triggers.file import FileTrigger
from airflow.providers.standard.triggers.file import FileDeleteTrigger, FileTrigger
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS


class TestFileTrigger:
@@ -62,3 +63,44 @@ async def test_task_file_trigger(self, tmp_path):

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()


@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0")
class TestFileDeleteTrigger:
FILE_PATH = "/files/dags/example_async_file.py"

def test_serialization(self):
"""Asserts that the trigger correctly serializes its arguments and classpath."""
trigger = FileDeleteTrigger(filepath=self.FILE_PATH, poke_interval=5)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.standard.triggers.file.FileDeleteTrigger"
assert kwargs == {
"filepath": self.FILE_PATH,
"poke_interval": 5,
}

@pytest.mark.asyncio
async def test_file_delete_trigger(self, tmp_path):
"""Asserts that the trigger goes off on or after file is found and that the files gets deleted."""
tmp_dir = tmp_path / "test_dir"
tmp_dir.mkdir()
p = tmp_dir / "hello.txt"

trigger = FileDeleteTrigger(
filepath=str(p.resolve()),
poke_interval=0.2,
)

task = asyncio.create_task(trigger.run().__anext__())
await asyncio.sleep(0.5)

# It should not have produced a result
assert task.done() is False

p.touch()

await asyncio.sleep(0.5)
assert p.exists() is False

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()
22 changes: 18 additions & 4 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -38,7 +38,7 @@

from airflow.models.asset import AssetModel
from airflow.serialization.serialized_objects import SerializedAssetWatcher
from airflow.triggers.base import BaseTrigger
from airflow.triggers.base import BaseEventTrigger

AttrsInstance = attrs.AttrsInstance
else:
@@ -254,7 +254,7 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
raise NotImplementedError


@attrs.define(frozen=True)
@attrs.define(init=False)
class AssetWatcher:
"""A representation of an asset watcher. The name uniquely identifies the watch."""

@@ -263,8 +263,22 @@ class AssetWatcher:
# For a "normal" asset instance loaded from DAG, this holds the trigger used to monitor an external
# resource. In that case, ``AssetWatcher`` is used directly by users.
# For an asset recreated from a serialized DAG, this holds the serialized data of the trigger. In that
# case, `SerializedAssetWatcher` is used. We need to keep the two types to make mypy happy.
trigger: BaseTrigger | dict
# case, `SerializedAssetWatcher` is used. We need to keep the two types to make mypy happy because
# `SerializedAssetWatcher` is a subclass of `AssetWatcher`.
trigger: BaseEventTrigger | dict

def __init__(
self,
name: str,
trigger: BaseEventTrigger | dict,
) -> None:
from airflow.triggers.base import BaseEventTrigger, BaseTrigger

if isinstance(trigger, BaseTrigger) and not isinstance(trigger, BaseEventTrigger):
raise ValueError("The trigger used to watch an asset must inherit ``BaseEventTrigger``")

self.name = name
self.trigger = trigger


@attrs.define(init=False, unsafe_hash=False)
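A quick sketch of the behavior the new ``AssetWatcher.__init__`` validation enforces, using ``FileTrigger`` as an example of a plain ``BaseTrigger`` and ``FileDeleteTrigger`` as an event trigger (test-style assertions for illustration only):

```python
import pytest

from airflow.providers.standard.triggers.file import FileDeleteTrigger, FileTrigger
from airflow.sdk import AssetWatcher

# An event trigger is accepted as a watcher trigger...
AssetWatcher(name="ok", trigger=FileDeleteTrigger(filepath="/tmp/watched"))

# ...but a trigger that does not inherit BaseEventTrigger is rejected.
with pytest.raises(ValueError, match="must inherit"):
    AssetWatcher(name="rejected", trigger=FileTrigger(filepath="/tmp/watched"))
```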
4 changes: 2 additions & 2 deletions tests/serialization/test_serialized_objects.py
@@ -41,7 +41,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.triggers.file import FileTrigger
from airflow.providers.standard.triggers.file import FileDeleteTrigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey, AssetWatcher
from airflow.sdk.definitions.param import Param
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
@@ -259,7 +259,7 @@ def __len__(self) -> int:
Asset(
uri="test://asset1",
name="test",
watchers=[AssetWatcher(name="test", trigger=FileTrigger(filepath="/tmp"))],
watchers=[AssetWatcher(name="test", trigger=FileDeleteTrigger(filepath="/tmp"))],
),
DAT.ASSET,
equals,