@@ -1588,6 +1588,11 @@ def _deserialize_field_value(cls, field_name: str, value: Any) -> Any:
elif field_name == "resources":
return Resources.from_dict(value) if value is not None else None
elif field_name.endswith("_date"):
# Check if value is ARG_NOT_SET before trying to deserialize as datetime
if isinstance(value, dict) and value.get(Encoding.TYPE) == DAT.ARG_NOT_SET:
from airflow.serialization.definitions.notset import NOTSET

return NOTSET
return cls._deserialize_datetime(value) if value is not None else None
else:
# For all other fields, return as-is (strings, ints, bools, etc.)
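The hunk above lets `*_date` fields carry the "argument was never set" marker through deserialization instead of being parsed as datetimes. Below is a minimal, self-contained sketch of that dispatch; `ENCODING_TYPE_KEY`, `DAT_ARG_NOT_SET`, and the `_NotSet` sentinel are placeholders, not Airflow's real `Encoding.TYPE`, `DAT.ARG_NOT_SET`, or `NOTSET` objects.

```python
from datetime import datetime, timezone
from typing import Any

ENCODING_TYPE_KEY = "__type"      # placeholder for Encoding.TYPE
DAT_ARG_NOT_SET = "arg_not_set"   # placeholder for DAT.ARG_NOT_SET


class _NotSet:
    """Stand-in for Airflow's NOTSET sentinel."""


NOTSET = _NotSet()


def deserialize_date_field(value: Any) -> Any:
    # A *_date field may carry the "argument was never set" marker instead
    # of an encoded datetime; surface the sentinel instead of trying (and
    # failing) to parse it as a datetime.
    if isinstance(value, dict) and value.get(ENCODING_TYPE_KEY) == DAT_ARG_NOT_SET:
        return NOTSET
    if value is None:
        return None
    return datetime.fromtimestamp(value, tz=timezone.utc)


print(deserialize_date_field({"__type": "arg_not_set"}) is NOTSET)  # True
print(deserialize_date_field(1_700_000_000))                        # a datetime
print(deserialize_date_field(None))                                 # None
```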
@@ -92,17 +92,32 @@ async def test_running(self, mock_check_for_blob, blob_exists):
asyncio.get_event_loop().stop()

@pytest.mark.asyncio
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.get_async_conn")
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_blob_async")
async def test_success(self, mock_check_for_blob):
async def test_success(self, mock_check_for_blob, mock_get_async_conn):
"""Tests the success state for that the WasbBlobSensorTrigger."""
# Mock get_async_conn to return an async context manager that doesn't block
mock_conn = mock.AsyncMock()
mock_get_async_conn.return_value = mock_conn
mock_check_for_blob.return_value = True

task = asyncio.create_task(self.TRIGGER.run().__anext__())
await asyncio.sleep(0.5)
# Wait for the task to complete with a timeout
try:
await asyncio.wait_for(task, timeout=2.0)
except asyncio.TimeoutError:
# If task didn't complete, check if it's done
if not task.done():
pytest.fail("Task did not complete within timeout")
finally:
try:
asyncio.get_event_loop().stop()
except RuntimeError:
# Event loop may already be stopped
pass

# TriggerEvent was returned
assert task.done() is True
asyncio.get_event_loop().stop()

message = f"Blob {TEST_DATA_STORAGE_BLOB_NAME} found in container {TEST_DATA_STORAGE_CONTAINER_NAME}."
assert task.result() == TriggerEvent({"status": "success", "message": message})
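The rewritten test waits on the trigger's first event with `asyncio.wait_for` instead of sleeping for a fixed half second. A stripped-down sketch of the same pattern against a toy async generator (assumes `pytest-asyncio` is installed; the generator stands in for a Trigger's `run()`, not the Azure hook):

```python
import asyncio

import pytest


async def toy_trigger_run():
    # Stand-in for an Airflow Trigger.run(): an async generator that
    # yields a single "success" event.
    yield {"status": "success"}


@pytest.mark.asyncio
async def test_toy_trigger_emits_event():
    task = asyncio.create_task(toy_trigger_run().__anext__())
    try:
        # Bounded wait: fails fast if the trigger never emits, instead of
        # relying on a fixed sleep being "long enough".
        result = await asyncio.wait_for(task, timeout=2.0)
    except asyncio.TimeoutError:
        pytest.fail("Trigger did not emit an event within the timeout")

    assert task.done()
    assert result == {"status": "success"}
```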
@@ -182,17 +197,32 @@ async def test_running(self, mock_check_for_prefix, prefix_exists):
asyncio.get_event_loop().stop()

@pytest.mark.asyncio
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.get_async_conn")
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbAsyncHook.check_for_prefix_async")
async def test_success(self, mock_check_for_prefix):
async def test_success(self, mock_check_for_prefix, mock_get_async_conn):
"""Tests the success state for that the WasbPrefixSensorTrigger."""
# Mock get_async_conn to return an async context manager that doesn't block
mock_conn = mock.AsyncMock()
mock_get_async_conn.return_value = mock_conn
mock_check_for_prefix.return_value = True

task = asyncio.create_task(self.TRIGGER.run().__anext__())
await asyncio.sleep(0.5)
# Wait for the task to complete with a timeout
try:
await asyncio.wait_for(task, timeout=2.0)
except asyncio.TimeoutError:
# If task didn't complete, check if it's done
if not task.done():
pytest.fail("Task did not complete within timeout")
finally:
try:
asyncio.get_event_loop().stop()
except RuntimeError:
# Event loop may already be stopped
pass

# TriggerEvent was returned
assert task.done() is True
asyncio.get_event_loop().stop()

message = (
f"Prefix {TEST_DATA_STORAGE_BLOB_PREFIX} found in container {TEST_DATA_STORAGE_CONTAINER_NAME}."
@@ -53,6 +53,7 @@

XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
XCOM_RUN_ID = "trigger_run_id"
XCOM_DAG_ID = "trigger_dag_id"


if TYPE_CHECKING:
@@ -86,12 +87,17 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
if TYPE_CHECKING:
assert isinstance(operator, TriggerDagRunOperator)

trigger_dag_id = operator.trigger_dag_id
if not AIRFLOW_V_3_0_PLUS:
from airflow.models.renderedtifields import RenderedTaskInstanceFields
# Try to get the resolved dag_id from XCom first (for dynamic dag_ids)
trigger_dag_id = XCom.get_value(ti_key=ti_key, key=XCOM_DAG_ID)

if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef]
# Fall back to the operator attribute and rendered fields if not in XCom
if not trigger_dag_id:
trigger_dag_id = operator.trigger_dag_id
if not AIRFLOW_V_3_0_PLUS:
from airflow.models.renderedtifields import RenderedTaskInstanceFields

if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef]

# Fetch the correct dag_run_id for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
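The link code above now prefers the dag_id pushed to XCom at run time and only then falls back to the operator attribute (and, on Airflow 2, rendered template fields). A small illustrative helper, with made-up names, showing just that resolution order:

```python
from typing import Callable, Optional


def resolve_trigger_dag_id(
    xcom_lookup: Callable[[str], Optional[str]],
    operator_trigger_dag_id: str,
) -> str:
    # The value stored in XCom was rendered at execution time, so it wins
    # for dynamic dag_ids; the static operator attribute is the fallback.
    return xcom_lookup("trigger_dag_id") or operator_trigger_dag_id


# Dynamic case: the rendered dag_id was pushed to XCom during execution.
print(resolve_trigger_dag_id(lambda key: "resolved_from_xcom", "literal_dag"))
# Static case: nothing in XCom, so the operator attribute is used.
print(resolve_trigger_dag_id(lambda key: None, "literal_dag"))
```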
@@ -195,7 +201,7 @@ def __init__(
self.openlineage_inject_parent_info = openlineage_inject_parent_info
self.deferrable = deferrable
self.logical_date = logical_date
if logical_date is NOTSET:
if isinstance(logical_date, ArgNotSet) or logical_date is NOTSET:
self.logical_date = NOTSET
elif logical_date is None or isinstance(logical_date, (str, datetime.datetime)):
self.logical_date = logical_date
@@ -208,7 +214,7 @@
raise NotImplementedError("Setting `fail_when_dag_is_paused` not yet supported for Airflow 3.x")

def execute(self, context: Context):
if self.logical_date is NOTSET:
if isinstance(self.logical_date, ArgNotSet) or self.logical_date is NOTSET:
# If no logical_date is provided we will set utcnow()
parsed_logical_date = timezone.utcnow()
elif self.logical_date is None or isinstance(self.logical_date, datetime.datetime):
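The widened `isinstance(..., ArgNotSet)` checks guard against the sentinel losing identity once the operator has been serialized and re-hydrated. A minimal sketch, using a stand-in sentinel class rather than Airflow's, of why an identity comparison can fail after a round-trip while the type check still holds:

```python
import copy


class ArgNotSet:
    """Stand-in for a 'value was not supplied' sentinel class."""


NOTSET = ArgNotSet()

# A round-trip such as deepcopy or pickle re-hydration can hand back a
# *new* instance of the sentinel class rather than the original object.
round_tripped = copy.deepcopy(NOTSET)

print(round_tripped is NOTSET)               # False: identity is lost
print(isinstance(round_tripped, ArgNotSet))  # True: the type check holds
```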
@@ -266,6 +272,14 @@ def execute(self, context: Context):
def _trigger_dag_af_3(self, context, run_id, parsed_logical_date):
from airflow.providers.common.compat.sdk import DagRunTriggerException

# Store the resolved dag_id to XCom for use in the link generation
# This is important for dynamic dag_ids (from XCom or complex templates)
# In Airflow 3.x, context has both "task_instance" and "ti" keys
if "task_instance" in context:
context["task_instance"].xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id)
elif "ti" in context:
context["ti"].xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id)

raise DagRunTriggerException(
trigger_dag_id=self.trigger_dag_id,
dag_run_id=run_id,
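The push above needs a task-instance handle, which the execution context may expose under either "task_instance" or "ti". A toy sketch of that lookup-and-push, using a plain dict and a fake TI object rather than Airflow's Context:

```python
# Illustrative only: a plain dict stands in for the execution context and
# _FakeTI for the task instance; names are not Airflow's API.
XCOM_DAG_ID = "trigger_dag_id"


class _FakeTI:
    def __init__(self):
        self.pushed = {}

    def xcom_push(self, key, value):
        self.pushed[key] = value


def push_resolved_dag_id(context: dict, dag_id: str) -> None:
    # Prefer "task_instance", fall back to "ti"; skip the push if neither
    # handle is present in the context.
    ti = context.get("task_instance") or context.get("ti")
    if ti is not None:
        ti.xcom_push(key=XCOM_DAG_ID, value=dag_id)


ti = _FakeTI()
push_resolved_dag_id({"ti": ti}, "target_dag")
print(ti.pushed)  # {'trigger_dag_id': 'target_dag'}
```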
@@ -311,10 +325,11 @@ def _trigger_dag_af_2(self, context, run_id, parsed_logical_date):
raise e
if dag_run is None:
raise RuntimeError("The dag_run should be set here!")
# Store the run id from the dag run (either created or found above) to
# Store the run id and dag_id from the dag run (either created or found above) to
# be used when creating the extra link on the webserver.
ti = context["task_instance"]
ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)
ti.xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id)

if self.wait_for_completion:
# Kick off the deferral process
@@ -140,8 +140,10 @@ def test_trigger_dagrun(self):
assert task.trigger_run_id == expected_run_id # run_id is saved as attribute

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
@mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_one")
def test_extra_operator_link(self, mock_xcom_get_one, dag_maker):
@mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value")
def test_extra_operator_link(self, mock_xcom_get_value, dag_maker):
from airflow.providers.standard.operators.trigger_dagrun import XCOM_RUN_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task",
@@ -153,14 +155,54 @@ def test_extra_operator_link(self, mock_xcom_get_one, dag_maker):
dr = dag_maker.create_dagrun(run_id="test_run_id")
ti = dr.get_task_instance(task_id=task.task_id)

mock_xcom_get_one.return_value = ti.run_id
# Mock XCom.get_value to return None for dag_id but return run_id for XCOM_RUN_ID
def mock_get_value(ti_key, key):
if key == XCOM_RUN_ID:
return "test_run_id"
return None

mock_xcom_get_value.side_effect = mock_get_value

link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key)

base_url = conf.get("api", "base_url", fallback="/").lower()
expected_url = f"{base_url}dags/{TRIGGERED_DAG_ID}/runs/test_run_id"
assert link == expected_url, f"Expected {expected_url}, but got {link}"

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
@mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value")
def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker):
"""Test that operator link works correctly when dag_id is dynamically resolved from XCom."""
from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task",
# In a real scenario, this would be a template like "{{ ti.xcom_pull(...) }}"
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="test_run_id",
)

dr = dag_maker.create_dagrun(run_id="test_run_id")
ti = dr.get_task_instance(task_id=task.task_id)

# Mock XCom.get_value to return our test values
def mock_get_value(ti_key, key):
if key == XCOM_DAG_ID:
return "dynamic_dag_id"
if key == XCOM_RUN_ID:
return "dynamic_run_id"
return None

mock_xcom_get_value.side_effect = mock_get_value

link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key)

base_url = conf.get("api", "base_url", fallback="/").lower()
# Should use the dag_id from XCom, not the operator attribute
expected_url = f"{base_url}dags/dynamic_dag_id/runs/dynamic_run_id"
assert link == expected_url, f"Expected {expected_url}, but got {link}"

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
def test_trigger_dagrun_custom_run_id(self):
task = TriggerDagRunOperator(
@@ -174,6 +216,38 @@ def test_trigger_dagrun_custom_run_id(self):

assert exc_info.value.dag_run_id == "custom_run_id"

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
def test_trigger_dagrun_pushes_dag_id_to_xcom(self, dag_maker):
"""Test that TriggerDagRunOperator pushes the resolved dag_id to XCom during execution."""
from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task",
trigger_dag_id=TRIGGERED_DAG_ID,
)

dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(task_id=task.task_id)

# Create a mock task instance that stores XCom values
xcom_values = {}

def mock_xcom_push(key, value, **kwargs):
xcom_values[key] = value

ti.xcom_push = mock_xcom_push

# Execute the task (it raises an exception in AF3, but should push to XCom first)
try:
task.execute(context={"task_instance": ti})
except DagRunTriggerException:
pass # Expected in Airflow 3

# Verify that the dag_id was pushed to XCom
assert XCOM_DAG_ID in xcom_values
assert xcom_values[XCOM_DAG_ID] == TRIGGERED_DAG_ID

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
def test_trigger_dagrun_with_logical_date(self):
"""Test TriggerDagRunOperator with custom logical_date."""
@@ -577,8 +651,37 @@ def test_explicitly_provided_trigger_run_id_is_saved_as_attr(self, dag_maker, se

assert task.trigger_run_id == "test_run_id"

def test_extra_operator_link(self, dag_maker, session):
def test_trigger_dagrun_pushes_dag_id_to_xcom(self, dag_maker, session):
"""Test that TriggerDagRunOperator pushes the resolved dag_id to XCom during execution."""
from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task",
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="test_run_id",
)
dag_maker.create_dagrun()
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

triggering_ti = session.scalar(
select(TaskInstance).filter_by(task_id=task.task_id, dag_id=task.dag_id)
)
assert triggering_ti is not None

# Verify that the dag_id was pushed to XCom
dag_id_xcom = triggering_ti.xcom_pull(key=XCOM_DAG_ID)
assert dag_id_xcom == TRIGGERED_DAG_ID

# Also verify run_id is still pushed
run_id_xcom = triggering_ti.xcom_pull(key=XCOM_RUN_ID)
assert run_id_xcom == "test_run_id"

@mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value")
def test_extra_operator_link(self, mock_xcom_get_value, dag_maker, session):
"""Asserts whether the correct extra links url will be created."""
from airflow.providers.standard.operators.trigger_dagrun import XCOM_RUN_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="test_run_id"
@@ -592,6 +695,14 @@ def test_extra_operator_link(self, dag_maker, session):
)
)

# Mock XCom.get_value to return None for dag_id but return run_id for XCOM_RUN_ID
def mock_get_value(ti_key, key):
if key == XCOM_RUN_ID:
return "test_run_id"
return None

mock_xcom_get_value.side_effect = mock_get_value

with mock.patch("airflow.utils.helpers.build_airflow_url_with_query") as mock_build_url:
# This is equivalent of a task run calling this and pushing to xcom
task.operator_extra_links[0].get_link(operator=task, ti_key=triggering_ti.key)
@@ -603,6 +714,47 @@
}
assert expected_args in args

@mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value")
def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker, session):
"""Test that operator link works correctly when dag_id is dynamically resolved from XCom."""
from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task",
# In a real scenario, this would be a template like "{{ ti.xcom_pull(...) }}"
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="test_run_id",
)
dag_maker.create_dagrun()
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

triggering_ti = session.scalar(
select(TaskInstance).filter_by(task_id=task.task_id, dag_id=task.dag_id)
)
assert triggering_ti is not None

# Mock XCom.get_value to return our test values
def mock_get_value(ti_key, key):
if key == XCOM_DAG_ID:
return "dynamic_dag_id"
if key == XCOM_RUN_ID:
return "test_run_id"
return None

mock_xcom_get_value.side_effect = mock_get_value

with mock.patch("airflow.utils.helpers.build_airflow_url_with_query") as mock_build_url:
task.operator_extra_links[0].get_link(operator=task, ti_key=triggering_ti.key)
assert mock_build_url.called
args, _ = mock_build_url.call_args
# Should use the dag_id from XCom, not the operator attribute
expected_args = {
"dag_id": "dynamic_dag_id",
"dag_run_id": "test_run_id",
}
assert expected_args in args

def test_trigger_dagrun_with_logical_date(self, dag_maker):
"""Test TriggerDagRunOperator with custom logical_date."""
custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5)