Dag test without sensor (#40010)
jannisko authored Jun 16, 2024
1 parent 518a9e4 commit 60c2d36
Showing 5 changed files with 91 additions and 5 deletions.
8 changes: 8 additions & 0 deletions airflow/cli/cli_config.py
@@ -423,6 +423,13 @@ def string_lower_type(val):
"If set, it uses the executor configured in the environment.",
action="store_true",
)
ARG_MARK_SUCCESS_PATTERN = Arg(
("--mark-success-pattern",),
help=(
"Don't run task_ids matching the regex <MARK_SUCCESS_PATTERN>, mark them as successful instead.\n"
"Can be used to skip e.g. dependency check sensors or cleanup steps in local testing.\n"
),
)
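# Hypothetical usage of the flag defined above (illustration, not part of this commit):
#   airflow dags test example_dag 2024-01-01 --mark-success-pattern "wait_for_.*"
# Matching is done with fullmatch (see airflow/models/dag.py below), so the
# pattern must cover the whole task_id.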

# list_tasks
ARG_TREE = Arg(("-t", "--tree"), help="Tree view", action="store_true")
@@ -1288,6 +1295,7 @@ class GroupCommand(NamedTuple):
ARG_SAVE_DAGRUN,
ARG_USE_EXECUTOR,
ARG_VERBOSE,
ARG_MARK_SUCCESS_PATTERN,
),
),
ActionCommand(
12 changes: 11 additions & 1 deletion airflow/cli/commands/dag_command.py
@@ -29,6 +29,7 @@
import warnings
from typing import TYPE_CHECKING

import re2
from sqlalchemy import delete, select

from airflow import settings
@@ -606,10 +607,19 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
raise SystemExit(f"Configuration {args.conf!r} is not valid JSON. Error: {e}")
execution_date = args.execution_date or timezone.utcnow()
use_executor = args.use_executor

mark_success_pattern = (
re2.compile(args.mark_success_pattern) if args.mark_success_pattern is not None else None
)

with _airflow_parsing_context_manager(dag_id=args.dag_id):
dag = dag or get_dag(subdir=args.subdir, dag_id=args.dag_id)
dr: DagRun = dag.test(
execution_date=execution_date, run_conf=run_conf, use_executor=use_executor, session=session
execution_date=execution_date,
run_conf=run_conf,
use_executor=use_executor,
mark_success_pattern=mark_success_pattern,
session=session,
)
show_dagrun = args.show_dagrun
imgcat = args.imgcat_dagrun
21 changes: 18 additions & 3 deletions airflow/models/dag.py
@@ -2886,6 +2886,7 @@ def test(
conn_file_path: str | None = None,
variable_file_path: str | None = None,
use_executor: bool = False,
mark_success_pattern: Pattern | str | None = None,
session: Session = NEW_SESSION,
) -> DagRun:
"""
@@ -2896,6 +2897,7 @@
:param conn_file_path: file path to a connection file in either yaml or json
:param variable_file_path: file path to a variable file in either yaml or json
:param use_executor: if set, uses an executor to test the DAG
:param mark_success_pattern: regex of task_ids to mark as success instead of running
:param session: database connection (optional)
"""

@@ -2983,6 +2985,12 @@ def add_logger_if_needed(ti: TaskInstance):
for ti in scheduled_tis:
ti.task = tasks[ti.task_id]

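# When the task_id fully matches the user-supplied pattern, the task is
# marked successful instead of being executed.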
mark_success = (
re2.compile(mark_success_pattern).fullmatch(ti.task_id) is not None
if mark_success_pattern is not None
else False
)

if use_executor:
if executor.has_task(ti):
continue
@@ -2992,7 +3000,12 @@ def add_logger_if_needed(ti: TaskInstance):
# Run the task locally
try:
add_logger_if_needed(ti)
_run_task(ti=ti, inline_trigger=not triggerer_running, session=session)
_run_task(
ti=ti,
inline_trigger=not triggerer_running,
session=session,
mark_success=mark_success,
)
except Exception:
self.log.exception("Task failed; ti=%s", ti)
if use_executor:
@@ -4209,7 +4222,9 @@ async def _run_inline_trigger_main():
return asyncio.run(_run_inline_trigger_main())


def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
def _run_task(
*, ti: TaskInstance, inline_trigger: bool = False, mark_success: bool = False, session: Session
):
"""
Run a single task instance, and push result to Xcom for downstream tasks.
Expand All @@ -4223,7 +4238,7 @@ def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Sessio
while True:
try:
log.info("[DAG TEST] running task %s", ti)
ti._run_raw_task(session=session, raise_on_defer=inline_trigger)
ti._run_raw_task(session=session, raise_on_defer=inline_trigger, mark_success=mark_success)
break
except TaskDeferred as e:
log.info("[DAG TEST] running trigger in line")
25 changes: 25 additions & 0 deletions docs/apache-airflow/core-concepts/debug.rst
@@ -44,6 +44,31 @@ needed. Here are some examples of arguments:
executor, it just runs all the tasks locally.
By providing this argument, the DAG is executed using the executor configured in the Airflow environment.

Conditionally skipping tasks
----------------------------

If you don't wish to execute some subset of tasks in your local environment (e.g. dependency check sensors or cleanup steps),
you can automatically mark them successful by supplying a pattern matching their ``task_id`` to the ``mark_success_pattern`` argument.
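The pattern is applied with ``fullmatch``, so it must cover the entire ``task_id``, not just a prefix
(the implementation compiles it with ``re2``, whose interface mirrors the stdlib ``re``). A minimal
sketch of the matching semantics, using ``re`` for illustration:

.. code-block:: python

    import re

    pattern = re.compile("wait_for_.*|cleanup")

    # fullmatch: the whole task_id has to match the pattern
    assert pattern.fullmatch("wait_for_ingestion_dag") is not None  # marked successful
    assert pattern.fullmatch("cleanup") is not None  # marked successful
    assert pattern.fullmatch("cleanup_old_files") is None  # still runs
    assert pattern.fullmatch("extract_stats_csv") is None  # still runs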

In the following example, testing the DAG won't wait for either of the upstream DAGs to complete. Instead, testing data
is ingested manually. The cleanup step is also skipped, leaving the intermediate csv available for inspection.

.. code-block:: python

    with DAG("example_dag", default_args=default_args) as dag:
        sensor = ExternalTaskSensor(task_id="wait_for_ingestion_dag", external_dag_id="ingest_raw_data")
        sensor2 = ExternalTaskSensor(task_id="wait_for_dim_dag", external_dag_id="ingest_dim")
        collect_stats = PythonOperator(task_id="extract_stats_csv", python_callable=extract_stats_csv)
        # ... run other tasks
        cleanup = PythonOperator(task_id="cleanup", python_callable=Path.unlink, op_args=[collect_stats.output])

        [sensor, sensor2] >> collect_stats >> cleanup

    if __name__ == "__main__":
        ingest_testing_data()
        run = dag.test(mark_success_pattern="wait_for_.*|cleanup")
        print(f"Intermediate csv: {run.get_task_instance('extract_stats_csv').xcom_pull(task_ids='extract_stats_csv')}")
Comparison with DebugExecutor
-----------------------------

30 changes: 29 additions & 1 deletion tests/cli/commands/test_dag_command.py
@@ -878,6 +878,7 @@ def test_dag_test(self, mock_get_dag):
run_conf=None,
use_executor=False,
session=mock.ANY,
mark_success_pattern=None,
),
]
)
@@ -907,7 +908,11 @@ def test_dag_test_no_execution_date(self, mock_utcnow, mock_get_dag):
[
mock.call(subdir=cli_args.subdir, dag_id="example_bash_operator"),
mock.call().test(
execution_date=mock.ANY, run_conf=None, use_executor=False, session=mock.ANY
execution_date=mock.ANY,
run_conf=None,
use_executor=False,
session=mock.ANY,
mark_success_pattern=None,
),
]
)
@@ -934,6 +939,7 @@ def test_dag_test_conf(self, mock_get_dag):
run_conf={"dag_run_conf_param": "param_value"},
use_executor=False,
session=mock.ANY,
mark_success_pattern=None,
),
]
)
@@ -957,6 +963,7 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_render_dag):
run_conf=None,
use_executor=False,
session=mock.ANY,
mark_success_pattern=None,
),
]
)
@@ -1033,3 +1040,24 @@ def execute(self, context, event=None):
assert mock_run.call_args_list[0] == ((trigger,), {})
tis = dr.get_task_instances()
assert next(x for x in tis if x.task_id == "abc").state == "success"

@mock.patch("airflow.models.taskinstance.TaskInstance._execute_task_with_callbacks")
def test_dag_test_with_mark_success(self, mock__execute_task_with_callbacks):
"""
option `--mark-success-pattern` should mark matching tasks as success without executing them.
"""
cli_args = self.parser.parse_args(
[
"dags",
"test",
"example_sensor_decorator",
datetime(2024, 1, 1, 0, 0, 0).isoformat(),
"--mark-success-pattern",
"wait_for_upstream",
]
)
dag_command.dag_test(cli_args)

# only second operator was actually executed, first one was marked as success
assert len(mock__execute_task_with_callbacks.call_args_list) == 1
assert mock__execute_task_with_callbacks.call_args_list[0].kwargs["self"].task_id == "dummy_operator"
