-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_dag_integrity.py
83 lines (64 loc) · 2.33 KB
/
test_dag_integrity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Test the validity of all DAGs. This test ensures that all Dags have tags, retries set to two, and no import errors. Feel free to add and remove tests."""
import os
import logging
from contextlib import contextmanager
import pytest
from airflow.models import DagBag
@contextmanager
def suppress_logging(namespace):
    """
    Temporarily disable the logger registered under ``namespace``.

    The logger's ``disabled`` flag is set to True for the duration of the
    ``with`` body and restored to its previous value on exit, including when
    the body raises.
    """
    target = logging.getLogger(namespace)
    # Remember the current state so an already-disabled logger stays disabled.
    previous_state, target.disabled = target.disabled, True
    try:
        yield
    finally:
        target.disabled = previous_state
def get_import_errors():
    """
    Collect the import errors recorded by a freshly built DagBag.

    Returns a list of ``(relative_path, message)`` pairs, one per entry in
    ``DagBag.import_errors``, with each path made relative to the
    ``AIRFLOW_HOME`` environment variable and each message stripped of
    surrounding whitespace. The list always starts with ``(None, None)``.
    """
    with suppress_logging("airflow"):
        bag = DagBag(include_examples=False)
        airflow_home = os.environ.get("AIRFLOW_HOME")

        # Lead with a (None, None) placeholder so a test object is always
        # created, even when there is nothing to report (a no-op case).
        collected = [(None, None)]
        for file_path, message in bag.import_errors.items():
            relative = os.path.relpath(file_path, airflow_home)
            collected.append((relative, message.strip()))
        return collected
def get_dags():
    """
    List every DAG found in a freshly built DagBag.

    Returns a list of ``(dag_id, dag, relative_fileloc)`` triples, where
    ``relative_fileloc`` is the DAG's ``fileloc`` made relative to the
    ``AIRFLOW_HOME`` environment variable.
    """
    with suppress_logging("airflow"):
        bag = DagBag(include_examples=False)
        airflow_home = os.environ.get("AIRFLOW_HOME")

        triples = []
        for dag_id, dag in bag.dags.items():
            relative = os.path.relpath(dag.fileloc, airflow_home)
            triples.append((dag_id, dag, relative))
        return triples
@pytest.mark.parametrize(
    "rel_path,rv",
    get_import_errors(),
    ids=[entry[0] for entry in get_import_errors()],
)
def test_file_imports(rel_path, rv):
    """
    Fail when the parametrized file has a recorded import error.

    ``rel_path`` is the file path and ``rv`` the error message; a case where
    either is empty or None passes without doing anything.
    """
    # Nothing to report for this case (e.g. the (None, None) placeholder).
    if not rel_path or not rv:
        return
    raise Exception(f"{rel_path} failed to import with message \n {rv}")
# Allow-list of DAG tags. Leave empty to skip the allow-list check; otherwise
# list the permitted tags, e.g. {"etl", "reporting"}. It must be a set (not a
# dict) because it is used as the right-hand operand of a set difference.
APPROVED_TAGS = set()
@pytest.mark.parametrize(
    "dag_id,dag,fileloc", get_dags(), ids=[x[2] for x in get_dags()]
)
def test_dag_tags(dag_id, dag, fileloc):
    """
    Check that a DAG is tagged and, when APPROVED_TAGS is non-empty, that
    every one of its tags appears in APPROVED_TAGS.
    """
    assert dag.tags, f"{dag_id} in {fileloc} has no tags"
    if APPROVED_TAGS:
        # Coerce to a set so the difference is well-defined whether
        # APPROVED_TAGS is declared as a set, list, tuple, or dict (keys);
        # a plain `set - dict` would raise TypeError instead of failing.
        unapproved = set(dag.tags) - set(APPROVED_TAGS)
        assert (
            not unapproved
        ), f"{dag_id} in {fileloc} has unapproved tags: {sorted(unapproved)}"
@pytest.mark.parametrize(
    "dag_id,dag,fileloc", get_dags(), ids=[x[2] for x in get_dags()]
)
def test_dag_retries(dag_id, dag, fileloc):
    """
    Check that a DAG's ``default_args`` set ``retries`` to at least 2.

    A DAG with no ``retries`` entry fails the assertion with the same
    message as one whose value is too low.
    """
    # `or {}` guards against default_args being None/empty; .get() then
    # yields None when "retries" is absent.
    retries = (dag.default_args or {}).get("retries")
    # Check for None first: `None >= 2` raises TypeError, which would hide
    # the real problem behind an unrelated traceback.
    assert (
        retries is not None and retries >= 2
    ), f"{dag_id} in {fileloc} must have retries set to at least 2 (found {retries!r})."