29 changes: 13 additions & 16 deletions airflow-core/src/airflow/cli/cli_config.py
@@ -29,7 +29,6 @@

import lazy_object_proxy

from airflow import settings
from airflow.cli.commands.legacy_commands import check_legacy_command
from airflow.configuration import conf
from airflow.utils.cli import ColorMode
@@ -157,15 +156,6 @@ def string_lower_type(val):
help="The logical date of the DAG or run_id of the DAGRun (optional)",
)
ARG_TASK_REGEX = Arg(("-t", "--task-regex"), help="The regex to filter specific task_ids (optional)")
ARG_SUBDIR = Arg(
("-S", "--subdir"),
help=(
"File location or directory from which to look for the dag. "
"Defaults to '[AIRFLOW_HOME]/dags' where [AIRFLOW_HOME] is the "
"value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' "
),
default="[AIRFLOW_HOME]/dags" if BUILD_DOCS else settings.DAGS_FOLDER,
)
ARG_BUNDLE_NAME = Arg(
(
"-B",
@@ -1176,7 +1166,7 @@ class GroupCommand(NamedTuple):
name="list",
help="List the tasks within a DAG",
func=lazy_load_command("airflow.cli.commands.task_command.task_list"),
args=(ARG_DAG_ID, ARG_SUBDIR, ARG_VERBOSE),
args=(ARG_DAG_ID, ARG_BUNDLE_NAME, ARG_VERBOSE),
),
ActionCommand(
name="clear",
@@ -1187,7 +1177,7 @@ class GroupCommand(NamedTuple):
ARG_TASK_REGEX,
ARG_START_DATE,
ARG_END_DATE,
ARG_SUBDIR,
ARG_BUNDLE_NAME,
ARG_UPSTREAM,
ARG_DOWNSTREAM,
ARG_YES,
@@ -1205,7 +1195,7 @@ class GroupCommand(NamedTuple):
ARG_DAG_ID,
ARG_TASK_ID,
ARG_LOGICAL_DATE_OR_RUN_ID,
ARG_SUBDIR,
ARG_BUNDLE_NAME,
ARG_VERBOSE,
ARG_MAP_INDEX,
),
@@ -1219,7 +1209,14 @@ class GroupCommand(NamedTuple):
"and then run by an executor."
),
func=lazy_load_command("airflow.cli.commands.task_command.task_failed_deps"),
args=(ARG_DAG_ID, ARG_TASK_ID, ARG_LOGICAL_DATE_OR_RUN_ID, ARG_SUBDIR, ARG_MAP_INDEX, ARG_VERBOSE),
args=(
ARG_DAG_ID,
ARG_TASK_ID,
ARG_LOGICAL_DATE_OR_RUN_ID,
ARG_BUNDLE_NAME,
ARG_MAP_INDEX,
ARG_VERBOSE,
),
),
ActionCommand(
name="render",
@@ -1229,7 +1226,7 @@
ARG_DAG_ID,
ARG_TASK_ID,
ARG_LOGICAL_DATE_OR_RUN_ID,
ARG_SUBDIR,
ARG_BUNDLE_NAME,
ARG_VERBOSE,
ARG_MAP_INDEX,
),
@@ -1246,7 +1243,7 @@
ARG_DAG_ID,
ARG_TASK_ID,
ARG_LOGICAL_DATE_OR_RUN_ID_OPTIONAL,
ARG_SUBDIR,
ARG_BUNDLE_NAME,
ARG_DRY_RUN,
ARG_TASK_PARAMS,
ARG_POST_MORTEM,
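The ARG_SUBDIR definition above is deleted outright, and every task command now takes ARG_BUNDLE_NAME instead. The diff truncates ARG_BUNDLE_NAME after its short flag, so as a rough sketch the argument presumably looks something like this (the long option name and help text are assumptions, not taken from the diff):

# Hypothetical reconstruction of the truncated Arg definition.
ARG_BUNDLE_NAME = Arg(
    ("-B", "--bundle-name"),  # assumed long flag; only "-B" is visible in the diff
    help="The name of the DAG bundle to search for the DAG in",  # assumed wording
)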
6 changes: 3 additions & 3 deletions airflow-core/src/airflow/cli/commands/dag_command.py
@@ -245,7 +245,7 @@ def dag_dependencies_show(args) -> None:
@providers_configuration_loaded
def dag_show(args) -> None:
"""Display DAG or saves its graphic representation to the file."""
dag = get_dag(subdir=None, dag_id=args.dag_id, from_db=True)
dag = get_dag(bundle_names=None, dag_id=args.dag_id, from_db=True)
dot = render_dag(dag)
filename = args.save
imgcat = args.imgcat
@@ -354,7 +354,7 @@ def dag_next_execution(args) -> None:
>>> airflow dags next-execution tutorial
2018-08-31 10:38:00
"""
dag = get_dag(subdir=None, dag_id=args.dag_id, from_db=True)
dag = get_dag(bundle_names=None, dag_id=args.dag_id, from_db=True)

with create_session() as session:
last_parsed_dag: DagModel = session.scalars(
@@ -610,7 +610,7 @@ def _render_dagrun(dr: DagRun) -> dict[str, str]:

def _parse_and_get_dag(dag_id: str) -> DAG | None:
"""Given a dag_id, determine the bundle and relative fileloc from the db, then parse and return the DAG."""
db_dag = get_dag(subdir=None, dag_id=dag_id, from_db=True)
db_dag = get_dag(bundle_names=None, dag_id=dag_id, from_db=True)
bundle_name = db_dag.get_bundle_name()
if bundle_name is None:
raise AirflowException(
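Note that all three updated call sites in this file pass bundle_names=None together with from_db=True, so the DAG is resolved from the serialized-DAG tables and no bundle is ever parsed. A minimal sketch of that call pattern (the dag id is hypothetical):

from airflow.utils.cli import get_dag

# from_db=True short-circuits bundle lookup entirely; the DAG comes straight
# from the metadata database, so no bundle names are needed.
dag = get_dag(bundle_names=None, dag_id="example_dag", from_db=True)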
15 changes: 8 additions & 7 deletions airflow-core/src/airflow/cli/commands/task_command.py
@@ -231,7 +231,7 @@ def task_failed_deps(args) -> None:
Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks
to have succeeded, but found 1 non-success(es).
"""
dag = get_dag(args.subdir, args.dag_id)
dag = get_dag(args.bundle_name, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
@@ -255,7 +255,7 @@ def task_state(args) -> None:
>>> airflow tasks state tutorial sleep 2015-01-01
success
"""
dag = get_dag(args.subdir, args.dag_id, from_db=True)
dag = get_dag(args.bundle_name, args.dag_id, from_db=True)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id)
print(ti.state)
@@ -266,7 +266,7 @@ def task_state(args) -> None:
@providers_configuration_loaded
def task_list(args, dag: DAG | None = None) -> None:
"""List the tasks within a DAG at the command line."""
dag = dag or get_dag(args.subdir, args.dag_id)
dag = dag or get_dag(args.bundle_name, args.dag_id)
tasks = sorted(t.task_id for t in dag.tasks)
print("\n".join(tasks))

@@ -365,7 +365,7 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
env_vars.update(args.env_vars)
os.environ.update(env_vars)

dag = dag or get_dag(args.subdir, args.dag_id)
dag = dag or get_dag(args.bundle_name, args.dag_id)

dag = DAG.from_sdk_dag(dag)

@@ -424,7 +424,8 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
def task_render(args, dag: DAG | None = None) -> None:
"""Render and displays templated fields for a given task."""
if not dag:
dag = get_dag(args.subdir, args.dag_id)
dag = get_dag(args.bundle_name, args.dag_id)
dag = DAG.from_sdk_dag(dag)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(
task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory"
@@ -448,12 +449,12 @@ def task_render(args, dag: DAG | None = None) -> None:
def task_clear(args) -> None:
"""Clear all task instances or only those matched by regex for a DAG(s)."""
logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
if args.dag_id and not args.subdir and not args.dag_regex and not args.task_regex:
if args.dag_id and not args.bundle_name and not args.dag_regex and not args.task_regex:
dags = [get_dag_by_file_location(args.dag_id)]
else:
# todo clear command only accepts a single dag_id. no reason for get_dags with 's' except regex?
# Reading from_db because clear method still not implemented in Task SDK DAG
dags = get_dags(args.subdir, args.dag_id, use_regex=args.dag_regex, from_db=True)
dags = get_dags(args.bundle_name, args.dag_id, use_regex=args.dag_regex, from_db=True)

if args.task_regex:
for idx, dag in enumerate(dags):
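With ARG_BUNDLE_NAME wired into these commands, args.bundle_name is what gets threaded into get_dag and get_dags above. Assuming the long flag is spelled --bundle-name (its definition is truncated earlier in this diff), invocations would presumably look like:

airflow tasks list example_dag --bundle-name my_bundle
airflow tasks clear example_dag --bundle-name my_bundle --task-regex "extract_.*"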
59 changes: 43 additions & 16 deletions airflow-core/src/airflow/utils/cli.py
@@ -258,48 +258,75 @@ def _search_for_dag_file(val: str | None) -> str | None:
return None


def get_dag(subdir: str | None, dag_id: str, from_db: bool = False) -> DAG:
def get_dag(bundle_names: list | None, dag_id: str, from_db: bool = False) -> DAG:
"""
Return DAG of a given dag_id.

When from_db is True, the DAG is read from the serialized DagBag in the database.
Otherwise we look for it in the given bundles first and, failing that, search
every configured bundle.
"""
from airflow.models.dag import DAG
from airflow.models.dagbag import DagBag

bundle_names = bundle_names or []
dag: DAG | None = None

if from_db:
dagbag = DagBag(read_dags_from_db=True)
dag = dagbag.get_dag(dag_id) # get_dag loads from the DB as requested
else:
first_path = process_subdir(subdir)
dagbag = DagBag(first_path)
dag = dagbag.dags.get(dag_id) # avoids db calls made in get_dag
# Create a SchedulerDAG since some of the CLI commands rely on DB access
dag = DAG.from_sdk_dag(dag)
elif bundle_names:
manager = DagBundlesManager()
for bundle_name in bundle_names:
bundle = manager.get_bundle(bundle_name)
dagbag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path)
dag = dagbag.dags.get(dag_id)
if dag:
break
if not dag:
if from_db:
raise AirflowException(f"Dag {dag_id!r} could not be found in DagBag read from database.")
fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
dagbag = DagBag(dag_folder=fallback_path)
dag = dagbag.get_dag(dag_id)
manager = DagBundlesManager()
all_bundles = list(manager.get_all_dag_bundles())
for bundle in all_bundles:
dag_bag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path)
dag = dag_bag.dags.get(dag_id)
if dag:
break
if not dag:
raise AirflowException(
f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
)
return dag


def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False, from_db: bool = False):
def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, from_db: bool = False):
"""Return DAG(s) matching a given regex or dag_id."""
from airflow.models import DagBag

bundle_names = bundle_names or []

if not use_regex:
return [get_dag(subdir=subdir, dag_id=dag_id, from_db=from_db)]
dagbag = DagBag(process_subdir(subdir))
matched_dags = [dag for dag in dagbag.dags.values() if re.search(dag_id, dag.dag_id)]
return [get_dag(bundle_names=bundle_names, dag_id=dag_id, from_db=from_db)]

def _find_dag(bundle):
dagbag = DagBag(dag_folder=bundle.path, bundle_path=bundle.path)
matched_dags = [dag for dag in dagbag.dags.values() if re.search(dag_id, dag.dag_id)]
return matched_dags

manager = DagBundlesManager()
matched_dags = []
for bundle_name in bundle_names:
bundle = manager.get_bundle(bundle_name)
matched_dags = _find_dag(bundle)
if matched_dags:
break
if not matched_dags:
# Search in all bundles
all_bundles = list(manager.get_all_dag_bundles())
for bundle in all_bundles:
matched_dags = _find_dag(bundle)
if matched_dags:
break
if not matched_dags:
raise AirflowException(
f"dag_id could not be found with regex: {dag_id}. Either the dag did not exist or "
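Taken together, the new lookup order in get_dag and get_dags is: the serialized database when from_db=True, then any explicitly requested bundles, then every configured bundle as a fallback. A minimal sketch exercising both helpers (bundle and dag names are hypothetical):

from airflow.utils.cli import get_dag, get_dags

# Parse only the named bundle first; get_dag falls back to scanning all
# configured bundles if the dag is not found there.
dag = get_dag(bundle_names=["my_bundle"], dag_id="example_dag")

# With use_regex=True, get_dags walks the same bundle order and returns
# every dag whose dag_id matches the pattern.
dags = get_dags(bundle_names=None, dag_id="^example_", use_regex=True)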
7 changes: 3 additions & 4 deletions airflow-core/tests/unit/cli/commands/test_task_command.py
@@ -234,11 +234,10 @@ def test_cli_test_with_env_vars(self):

@mock.patch("airflow.providers.standard.triggers.file.os.path.getmtime", return_value=0)
@mock.patch("airflow.providers.standard.triggers.file.glob", return_value=["/tmp/test"])
@mock.patch("airflow.providers.standard.triggers.file.os.path.isfile", return_value=True)
@mock.patch("airflow.providers.standard.triggers.file.os")
@mock.patch("airflow.providers.standard.sensors.filesystem.FileSensor.poke", return_value=False)
def test_cli_test_with_deferrable_operator(
self, mock_pock, mock_is_file, mock_glob, mock_getmtime, caplog
):
def test_cli_test_with_deferrable_operator(self, mock_pock, mock_os, mock_glob, mock_getmtime, caplog):
mock_os.path.isfile.return_value = True
with caplog.at_level(level=logging.INFO):
task_command.task_test(
self.parser.parse_args(
airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
@@ -43,7 +43,10 @@
def generate_pod_yaml(args):
"""Generate yaml files for each task in the DAG. Used for testing output of KubernetesExecutor."""
logical_date = args.logical_date if AIRFLOW_V_3_0_PLUS else args.execution_date
dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
if AIRFLOW_V_3_0_PLUS:
dag = get_dag(bundle_names=args.bundle_name, dag_id=args.dag_id)
else:
dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
yaml_output_path = args.output_path
if AIRFLOW_V_3_0_PLUS:
dr = DagRun(dag.dag_id, logical_date=logical_date)
airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -52,7 +52,6 @@
from airflow.cli.cli_config import (
ARG_DAG_ID,
ARG_OUTPUT_PATH,
ARG_SUBDIR,
ARG_VERBOSE,
ActionCommand,
Arg,
@@ -94,6 +93,16 @@
AirflowKubernetesScheduler,
)


if AIRFLOW_V_3_0_PLUS:
from airflow.cli.cli_config import ARG_BUNDLE_NAME

ARG_COMPAT = ARG_BUNDLE_NAME
else:
from airflow.cli.cli_config import ARG_SUBDIR # type: ignore[attr-defined]

ARG_COMPAT = ARG_SUBDIR

# CLI Args
ARG_NAMESPACE = Arg(
("--namespace",),
@@ -128,7 +137,7 @@
help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without "
"launching into a cluster",
func=lazy_load_command("airflow.providers.cncf.kubernetes.cli.kubernetes_command.generate_pod_yaml"),
args=(ARG_DAG_ID, ARG_LOGICAL_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH, ARG_VERBOSE),
args=(ARG_DAG_ID, ARG_LOGICAL_DATE, ARG_COMPAT, ARG_OUTPUT_PATH, ARG_VERBOSE),
),
)

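The ARG_COMPAT alias above resolves at import time, so the command table itself needs no version checks. For what it's worth, the same shim could equally be written as a try/except import; this is an alternative sketch, not what the PR does:

# Equivalent idea using ImportError as the version probe.
try:
    from airflow.cli.cli_config import ARG_BUNDLE_NAME as ARG_COMPAT  # Airflow 3.x
except ImportError:
    from airflow.cli.cli_config import ARG_SUBDIR as ARG_COMPAT  # Airflow 2.x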