Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion airflow-core/src/airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -84,6 +84,8 @@ class DagCallbackRequest(BaseCallbackRequest):
run_id: str
is_failure_callback: bool | None = True
"""Flag to determine whether it is a Failure Callback or Success Callback"""
dag_run: dict[str, Any] | None = None
"""Serialized dag_run information to be included in the callback context"""
type: Literal["DagCallbackRequest"] = "DagCallbackRequest"


Expand Down
5 changes: 4 additions & 1 deletion airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,16 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
return

callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
# TODO: We need a proper context object!
context: Context = {
"dag": dag,
"run_id": request.run_id,
"reason": request.msg,
}

# Only add dag_run to context if it's provided
if request.dag_run is not None:
context["dag_run"] = request.dag_run

for callback in callbacks:
log.info(
"Executing on_%s dag callback",
Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,7 @@ def _schedule_dag_run(
bundle_version=dag_run.bundle_version,
is_failure_callback=True,
msg="timed_out",
dag_run=dag_run.serialize_for_callback(),
)

dag_run.notify_dagrun_state_changed()
Expand Down
27 changes: 27 additions & 0 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,7 @@ def recalculate(self) -> _UnfinishedStates:
bundle_version=self.bundle_version,
is_failure_callback=True,
msg="task_failure",
dag_run=self.serialize_for_callback(),
)

# Check if the max_consecutive_failed_dag_runs has been provided and not 0
Expand Down Expand Up @@ -1217,6 +1218,7 @@ def recalculate(self) -> _UnfinishedStates:
bundle_version=self.bundle_version,
is_failure_callback=False,
msg="success",
dag_run=self.serialize_for_callback(),
)

if (deadline := dag.deadline) and isinstance(deadline.reference, DeadlineReference.TYPES.DAGRUN):
Expand All @@ -1240,6 +1242,7 @@ def recalculate(self) -> _UnfinishedStates:
bundle_version=self.bundle_version,
is_failure_callback=True,
msg="all_tasks_deadlocked",
dag_run=self.serialize_for_callback(),
)

# finally, if the leaves aren't done, the dag is still running
Expand Down Expand Up @@ -1356,6 +1359,7 @@ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "
"dag": dag,
"run_id": str(self.run_id),
"reason": reason,
"dag_run": self,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't going to fix it. As the docstring says, this function is only used in dag.test

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check #53058 for similar pattern

This comment was marked as outdated.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I finally managed to get the dag_run information without querying the database, and to pass it along until it can be set in the callback context.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! Since we need more keys, though, I was working on it in parallel, as it had become a blocker for some. Check #53684.

}

callbacks = dag.on_success_callback if success else dag.on_failure_callback
Expand Down Expand Up @@ -2013,6 +2017,29 @@ def _get_log_template(log_template_id: int | None, session: Session = NEW_SESSIO
def _get_partial_task_ids(dag: DAG | None) -> list[str] | None:
return dag.task_ids if dag and dag.partial else None

def serialize_for_callback(self) -> dict[str, Any]:
    """
    Serialize this DagRun into a plain dictionary for callback requests.

    The result is a snapshot built only from attributes already loaded on the
    instance; it is intended to be passed along with a callback request so the
    receiver does not need the DagRun object itself.

    Date/time attributes are rendered as ISO-8601 strings via ``isoformat()``
    and are ``None`` when unset. ``conf`` is deep-copied so that mutating the
    returned payload never changes this DagRun's own ``conf``. All other
    attributes are passed through unchanged.

    :return: Dictionary containing serialized DagRun information
    """
    # Local import keeps this method self-contained without touching the
    # module-level import block.
    import copy

    def _iso(moment):
        """Return ``moment`` as an ISO-8601 string, or None when it is unset."""
        return moment.isoformat() if moment is not None else None

    return {
        "dag_id": self.dag_id,
        "run_id": self.run_id,
        "state": self.state,
        "logical_date": _iso(self.logical_date),
        "start_date": _iso(self.start_date),
        "end_date": _iso(self.end_date),
        # Copy rather than alias: the payload must be independent of the live object.
        "conf": copy.deepcopy(self.conf),
        "run_type": self.run_type,
        "run_after": _iso(self.run_after),
        "data_interval_start": _iso(self.data_interval_start),
        "data_interval_end": _iso(self.data_interval_end),
    }

Comment on lines +2020 to +2042
Copy link
Member

@kaxil kaxil Jul 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need this, check #53684

That should take care of that and restore the Airflow 2 behavior

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean that the entire PR is no longer needed because it is being addressed in #53684, or just the part of the code you highlighted?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Entire PR. I do apologize for the time spent here, though. I can certainly say this will be useful context for future contributions. I really appreciate you taking the time, and I hope you will continue contributing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing this issue!


class DagRunNote(Base):
"""For storage of arbitrary notes concerning the dagrun instance."""
Expand Down
266 changes: 266 additions & 0 deletions airflow-core/tests/unit/callbacks/test_callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,269 @@ def test_is_failure_callback_property(
)

assert request.is_failure_callback == expected_is_failure


class TestDagCallbackRequest:
    """Test the DagCallbackRequest class with the new dag_run field."""

    # Run id used by the tests that exercise a fully populated dag_run payload.
    _FULL_RUN_ID = "test_run_2024-01-01T00:00:00+00:00"

    @staticmethod
    def _make_request(**overrides):
        """
        Build a DagCallbackRequest from baseline keyword arguments.

        ``dag_run`` is deliberately absent from the baseline so that tests which
        do not pass it exercise the field's default.
        """
        kwargs = {
            "filepath": "test_dag.py",
            "dag_id": "test_dag",
            "run_id": "test_run",
            "is_failure_callback": False,
            "bundle_name": "testing",
            "bundle_version": None,
            "msg": "success",
        }
        kwargs.update(overrides)
        return DagCallbackRequest(**kwargs)

    @staticmethod
    def _minimal_dag_run(**overrides):
        """Return a fresh dag_run payload holding only dag_id, run_id and state, plus overrides."""
        return {"dag_id": "test_dag", "run_id": "test_run", "state": "success", **overrides}

    @classmethod
    def _full_dag_run(cls, **overrides):
        """Return a fresh, fully populated dag_run payload, plus overrides."""
        return {
            "dag_id": "test_dag",
            "run_id": cls._FULL_RUN_ID,
            "state": "success",
            "logical_date": "2024-01-01T00:00:00+00:00",
            "start_date": "2024-01-01T00:00:00+00:00",
            "end_date": "2024-01-01T01:00:00+00:00",
            "run_type": "manual",
            "run_after": "2024-01-01T00:00:00+00:00",
            "conf": {"key": "value"},
            "data_interval_start": "2024-01-01T00:00:00+00:00",
            "data_interval_end": "2024-01-01T01:00:00+00:00",
            **overrides,
        }

    def test_dag_callback_request_with_dag_run(self):
        """Test DagCallbackRequest creation with dag_run field."""
        dag_run_data = self._full_dag_run()

        request = self._make_request(run_id=self._FULL_RUN_ID, dag_run=dag_run_data)

        assert request.dag_run == dag_run_data
        assert request.dag_run["dag_id"] == "test_dag"
        assert request.dag_run["run_id"] == "test_run_2024-01-01T00:00:00+00:00"
        assert request.dag_run["state"] == "success"
        assert request.dag_run["conf"]["key"] == "value"

    def test_dag_callback_request_without_dag_run(self):
        """Test DagCallbackRequest creation without dag_run field (backward compatibility)."""
        request = self._make_request()

        assert request.dag_run is None
        assert request.dag_id == "test_dag"
        assert request.run_id == "test_run"

    def test_dag_callback_request_serialization_with_dag_run(self):
        """Test DagCallbackRequest serialization and deserialization with dag_run field."""
        dag_run_data = self._full_dag_run(conf={"key": "value", "nested": {"inner": "data"}})
        original_request = self._make_request(run_id=self._FULL_RUN_ID, dag_run=dag_run_data)

        # Round-trip through JSON
        deserialized_request = DagCallbackRequest.from_json(original_request.to_json())

        # Verify all fields are preserved
        assert deserialized_request == original_request
        assert deserialized_request.dag_run == dag_run_data
        assert deserialized_request.dag_run["conf"]["nested"]["inner"] == "data"

    def test_dag_callback_request_serialization_without_dag_run(self):
        """Test DagCallbackRequest serialization and deserialization without dag_run field."""
        original_request = self._make_request(is_failure_callback=True, msg="task_failure")

        # Round-trip through JSON
        deserialized_request = DagCallbackRequest.from_json(original_request.to_json())

        # Verify all fields are preserved
        assert deserialized_request == original_request
        assert deserialized_request.dag_run is None

    def test_dag_callback_request_with_none_dag_run(self):
        """Test DagCallbackRequest with explicitly None dag_run field."""
        request = self._make_request(dag_run=None)

        assert request.dag_run is None

    def test_dag_callback_request_with_minimal_dag_run_data(self):
        """Test DagCallbackRequest with minimal dag_run data."""
        minimal_dag_run_data = self._minimal_dag_run()

        request = self._make_request(dag_run=minimal_dag_run_data)

        assert request.dag_run == minimal_dag_run_data
        assert request.dag_run["dag_id"] == "test_dag"
        assert request.dag_run["run_id"] == "test_run"
        assert request.dag_run["state"] == "success"

    def test_dag_callback_request_with_null_values_in_dag_run(self):
        """Test DagCallbackRequest with null values in dag_run data."""
        dag_run_data_with_nulls = self._full_dag_run(
            run_id="test_run",
            logical_date=None,
            end_date=None,
            conf=None,
            data_interval_start=None,
            data_interval_end=None,
        )

        request = self._make_request(dag_run=dag_run_data_with_nulls)

        assert request.dag_run == dag_run_data_with_nulls
        assert request.dag_run["logical_date"] is None
        assert request.dag_run["end_date"] is None
        assert request.dag_run["conf"] is None

    def test_dag_callback_request_equality_with_dag_run(self):
        """Test DagCallbackRequest equality comparison with dag_run field."""
        # Each request gets its own payload object so equality is checked by value.
        request1 = self._make_request(dag_run=self._minimal_dag_run(conf={"key": "value"}))
        request2 = self._make_request(dag_run=self._minimal_dag_run(conf={"key": "value"}))
        request3 = self._make_request(dag_run=self._minimal_dag_run(dag_id="different_dag"))

        assert request1 == request2
        assert request1 != request3
        assert request2 != request3

    def test_dag_callback_request_equality_without_dag_run(self):
        """Test DagCallbackRequest equality comparison without dag_run field."""
        request1 = self._make_request()
        request2 = self._make_request()
        request3 = self._make_request(dag_run={"dag_id": "test_dag", "run_id": "test_run"})

        assert request1 == request2
        assert request1 != request3  # Different because one has dag_run and the other doesn't
Loading
Loading