11 changes: 9 additions & 2 deletions airflow-core/src/airflow/cli/cli_config.py
@@ -233,6 +233,13 @@ def string_lower_type(val):
action="store_true",
)

# list_dags
ARG_LIST_LOCAL = Arg(
("-l", "--local"),
action="store_true",
help="Shows local parsed DAGs and their import errors, ignores content serialized in DB",
)

# list_dag_runs
ARG_NO_BACKFILL = Arg(
("--no-backfill",), help="filter all the backfill dagruns given the dag id", action="store_true"
@@ -959,13 +966,13 @@ class GroupCommand(NamedTuple):
name="list",
help="List all the DAGs",
func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_dags"),
args=(ARG_OUTPUT, ARG_VERBOSE, ARG_DAG_LIST_COLUMNS, ARG_BUNDLE_NAME),
args=(ARG_OUTPUT, ARG_VERBOSE, ARG_DAG_LIST_COLUMNS, ARG_BUNDLE_NAME, ARG_LIST_LOCAL),
),
ActionCommand(
name="list-import-errors",
help="List all the DAGs that have import errors",
func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_import_errors"),
args=(ARG_BUNDLE_NAME, ARG_OUTPUT, ARG_VERBOSE),
args=(ARG_BUNDLE_NAME, ARG_OUTPUT, ARG_VERBOSE, ARG_LIST_LOCAL),
),
ActionCommand(
name="report",
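Both commands gain the same boolean switch. Below is a minimal argparse sketch of the behaviour the new ARG_LIST_LOCAL definition should produce; the real CLI is assembled from the Arg and ActionCommand entries above, so this standalone parser is only an illustration.

import argparse

# Standalone approximation of the --local switch added above; Airflow's own CLI
# builds this from the Arg/ActionCommand definitions, not from raw argparse.
parser = argparse.ArgumentParser(prog="airflow dags list")
parser.add_argument(
    "-l", "--local",
    action="store_true",
    help="Shows local parsed DAGs and their import errors, ignores content serialized in DB",
)

assert parser.parse_args(["--local"]).local is True   # flag given -> True
assert parser.parse_args([]).local is False           # flag omitted -> defaults to False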
96 changes: 69 additions & 27 deletions airflow-core/src/airflow/cli/commands/dag_command.py
@@ -289,10 +289,10 @@ def _get_dagbag_dag_details(dag: DAG) -> dict:
return {
"dag_id": dag.dag_id,
"dag_display_name": dag.dag_display_name,
"bundle_name": dag.get_bundle_name(),
"bundle_version": dag.get_bundle_version(),
"is_paused": dag.get_is_paused(),
"is_stale": dag.get_is_stale(),
"bundle_name": dag.get_bundle_name() if hasattr(dag, "get_bundle_name") else None,
"bundle_version": dag.get_bundle_version() if hasattr(dag, "get_bundle_version") else None,
"is_paused": dag.get_is_paused() if hasattr(dag, "get_is_paused") else None,
"is_stale": dag.get_is_stale() if hasattr(dag, "get_is_stale") else None,
"last_parsed_time": None,
"last_expired": None,
"fileloc": dag.fileloc,
@@ -404,15 +404,38 @@ def dag_list_dags(args, session: Session = NEW_SESSION) -> None:
file=sys.stderr,
)

dagbag = DagBag(read_dags_from_db=True)
dagbag.collect_dags_from_db()
dagbag_import_errors = 0
dags_list = []
if args.local:
# Get import errors from the local area
if args.bundle_name:
manager = DagBundlesManager()
validate_dag_bundle_arg(args.bundle_name)
all_bundles = list(manager.get_all_dag_bundles())
bundles_to_search = set(args.bundle_name)

for bundle in all_bundles:
if bundle.name in bundles_to_search:
dagbag = DagBag(bundle.path, bundle_path=bundle.path)
dagbag.collect_dags()
dags_list.extend(list(dagbag.dags.values()))
dagbag_import_errors += len(dagbag.import_errors)
else:
dagbag = DagBag()
dagbag.collect_dags()
dags_list.extend(list(dagbag.dags.values()))
dagbag_import_errors += len(dagbag.import_errors)
else:
# Get import errors from the DB
dagbag = DagBag(read_dags_from_db=True)
dagbag.collect_dags_from_db()
dags_list = list(dagbag.dags.values())

# Get import errors from the DB
query = select(func.count()).select_from(ParseImportError)
if args.bundle_name:
query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))
query = select(func.count()).select_from(ParseImportError)
if args.bundle_name:
query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))

dagbag_import_errors = session.scalar(query)
dagbag_import_errors = session.scalar(query)

if dagbag_import_errors > 0:
from rich import print as rich_print
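Because the rendered diff interleaves removed and added lines and drops indentation, here is a condensed sketch of the branching that dag_list_dags ends up with, reconstructed from the diff above (the surrounding error reporting and output code is omitted):

# Inside dag_list_dags(args, session): condensed from the diff above.
dagbag_import_errors = 0
dags_list = []
if args.local:
    # Get import errors from the local area
    if args.bundle_name:
        manager = DagBundlesManager()
        validate_dag_bundle_arg(args.bundle_name)
        all_bundles = list(manager.get_all_dag_bundles())
        bundles_to_search = set(args.bundle_name)

        for bundle in all_bundles:
            if bundle.name in bundles_to_search:
                dagbag = DagBag(bundle.path, bundle_path=bundle.path)
                dagbag.collect_dags()
                dags_list.extend(list(dagbag.dags.values()))
                dagbag_import_errors += len(dagbag.import_errors)
    else:
        dagbag = DagBag()
        dagbag.collect_dags()
        dags_list.extend(list(dagbag.dags.values()))
        dagbag_import_errors += len(dagbag.import_errors)
else:
    # Get import errors from the DB
    dagbag = DagBag(read_dags_from_db=True)
    dagbag.collect_dags_from_db()
    dags_list = list(dagbag.dags.values())

    query = select(func.count()).select_from(ParseImportError)
    if args.bundle_name:
        query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))

    dagbag_import_errors = session.scalar(query)

The serialized-DB path is the pre-existing logic, now nested under else; the new --local path parses DAG files on disk, either for the named bundles or for the default DAG folder.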
@@ -441,7 +464,7 @@ def filter_dags_by_bundle(dags: list[DAG], bundle_names: list[str] | None) -> li

AirflowConsole().print_as(
data=sorted(
filter_dags_by_bundle(list(dagbag.dags.values()), args.bundle_name),
filter_dags_by_bundle(dags_list, args.bundle_name if not args.local else None),
key=operator.attrgetter("dag_id"),
),
output=args.output,
@@ -479,22 +502,41 @@ def dag_list_import_errors(args, session: Session = NEW_SESSION) -> None:
"""Display dags with import errors on the command line."""
data = []

# Get import errors from the DB
query = select(ParseImportError)
if args.bundle_name:
validate_dag_bundle_arg(args.bundle_name)
query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))

dagbag_import_errors = session.scalars(query).all()
if args.local:
# Get import errors from local areas
if args.bundle_name:
manager = DagBundlesManager()
validate_dag_bundle_arg(args.bundle_name)
all_bundles = list(manager.get_all_dag_bundles())
bundles_to_search = set(args.bundle_name)

for bundle in all_bundles:
if bundle.name in bundles_to_search:
dagbag = DagBag(bundle.path, bundle_path=bundle.path)
for filename, errors in dagbag.import_errors.items():
data.append({"bundle_name": bundle.name, "filepath": filename, "error": errors})
else:
dagbag = DagBag()
for filename, errors in dagbag.import_errors.items():
data.append({"filepath": filename, "error": errors})

for import_error in dagbag_import_errors:
data.append(
{
"bundle_name": import_error.bundle_name,
"filepath": import_error.filename,
"error": import_error.stacktrace,
}
)
else:
# Get import errors from the DB
query = select(ParseImportError)
if args.bundle_name:
validate_dag_bundle_arg(args.bundle_name)
query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))

dagbag_import_errors = session.scalars(query).all()

for import_error in dagbag_import_errors:
data.append(
{
"bundle_name": import_error.bundle_name,
"filepath": import_error.filename,
"error": import_error.stacktrace,
}
)
AirflowConsole().print_as(
data=data,
output=args.output,
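Likewise, a condensed sketch of the reshaped dag_list_import_errors body, reconstructed from the diff above (the DB branch is the previous logic, only nested under else):

# Inside dag_list_import_errors(args, session): condensed from the diff above.
data = []
if args.local:
    # Get import errors from local areas
    if args.bundle_name:
        manager = DagBundlesManager()
        validate_dag_bundle_arg(args.bundle_name)
        all_bundles = list(manager.get_all_dag_bundles())
        bundles_to_search = set(args.bundle_name)

        for bundle in all_bundles:
            if bundle.name in bundles_to_search:
                dagbag = DagBag(bundle.path, bundle_path=bundle.path)
                for filename, errors in dagbag.import_errors.items():
                    data.append({"bundle_name": bundle.name, "filepath": filename, "error": errors})
    else:
        dagbag = DagBag()
        for filename, errors in dagbag.import_errors.items():
            data.append({"filepath": filename, "error": errors})
else:
    # Get import errors from the DB
    query = select(ParseImportError)
    if args.bundle_name:
        validate_dag_bundle_arg(args.bundle_name)
        query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))

    dagbag_import_errors = session.scalars(query).all()

    for import_error in dagbag_import_errors:
        data.append(
            {
                "bundle_name": import_error.bundle_name,
                "filepath": import_error.filename,
                "error": import_error.stacktrace,
            }
        )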
56 changes: 56 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_dag_command.py
@@ -258,6 +258,42 @@ def test_cli_list_dags(self):
assert key in dag_list[0]
assert any("airflow/example_dags/example_complex.py" in d["fileloc"] for d in dag_list)

@conf_vars({("core", "load_examples"): "true"})
def test_cli_list_local_dags(self):
# Clear the database
clear_db_dags()
args = self.parser.parse_args(["dags", "list", "--output", "json", "--local"])
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
dag_command.dag_list_dags(args)
out = temp_stdout.getvalue()
dag_list = json.loads(out)
for key in ["dag_id", "fileloc", "owners", "is_paused"]:
assert key in dag_list[0]
assert any("airflow/example_dags/example_complex.py" in d["fileloc"] for d in dag_list)
# Rebuild Test DB for other tests
parse_and_sync_to_db(os.devnull, include_examples=True)

@conf_vars({("core", "load_examples"): "false"})
def test_cli_list_local_dags_with_bundle_name(self, configure_testing_dag_bundle):
# Clear the database
clear_db_dags()
path_to_parse = TEST_DAGS_FOLDER / "test_example_bash_operator.py"
args = self.parser.parse_args(
["dags", "list", "--output", "json", "--local", "--bundle-name", "testing"]
)
with configure_testing_dag_bundle(path_to_parse):
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
dag_command.dag_list_dags(args)
out = temp_stdout.getvalue()
dag_list = json.loads(out)
for key in ["dag_id", "fileloc", "owners", "is_paused"]:
assert key in dag_list[0]
assert any(
str(TEST_DAGS_FOLDER / "test_example_bash_operator.py") in d["fileloc"] for d in dag_list
)
# Rebuild Test DB for other tests
parse_and_sync_to_db(os.devnull, include_examples=True)

@conf_vars({("core", "load_examples"): "true"})
def test_cli_list_dags_custom_cols(self):
args = self.parser.parse_args(
@@ -294,6 +330,26 @@ def test_cli_list_dags_prints_import_errors(self, configure_testing_dag_bundle,

assert "Failed to load all files." in out

@conf_vars({("core", "load_examples"): "false"})
def test_cli_list_dags_prints_local_import_errors(self, configure_testing_dag_bundle, get_test_dag):
# Clear the database
clear_db_dags()
path_to_parse = TEST_DAGS_FOLDER / "test_invalid_cron.py"
get_test_dag("test_invalid_cron")

args = self.parser.parse_args(
["dags", "list", "--output", "yaml", "--bundle-name", "testing", "--local"]
)

with configure_testing_dag_bundle(path_to_parse):
with contextlib.redirect_stderr(StringIO()) as temp_stderr:
dag_command.dag_list_dags(args)
out = temp_stderr.getvalue()

assert "Failed to load all files." in out
# Rebuild Test DB for other tests
parse_and_sync_to_db(os.devnull, include_examples=True)

@conf_vars({("core", "load_examples"): "true"})
@mock.patch("airflow.models.DagModel.get_dagmodel")
def test_list_dags_none_get_dagmodel(self, mock_get_dagmodel):
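The new tests follow the existing pattern of parsing CLI args, capturing stdout, and decoding the JSON output. A minimal sketch of that pattern outside the test class, assuming a configured Airflow environment with example DAGs available on disk:

import contextlib
import json
from io import StringIO

from airflow.cli import cli_parser
from airflow.cli.commands import dag_command

parser = cli_parser.get_parser()
args = parser.parse_args(["dags", "list", "--output", "json", "--local"])

# Capture the command's stdout and parse the JSON it prints, as the tests do.
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
    dag_command.dag_list_dags(args)

dag_list = json.loads(temp_stdout.getvalue())
print(f"{len(dag_list)} locally parsed DAGs")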