Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
from airflow.utils.context import Context, context_get_outlet_events
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.helpers import validate_key
from airflow.utils.helpers import validate_instance_args, validate_key
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -521,6 +521,47 @@ def __new__(cls, name, bases, namespace, **kwargs):
return new_cls


# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the
# correct type. This is a temporary solution until we find a more sophisticated method for argument
# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not
# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
# version that supports `get_type_hints` effectively or find a better approach, we can replace this
# manual type-checking method.
BASEOPERATOR_ARGS_EXPECTED_TYPES = {
"task_id": str,
"email": (str, Iterable),
"email_on_retry": bool,
"email_on_failure": bool,
"retries": int,
"retry_exponential_backoff": bool,
"depends_on_past": bool,
"ignore_first_depends_on_past": bool,
"wait_for_past_depends_before_skipping": bool,
"wait_for_downstream": bool,
"priority_weight": int,
"queue": str,
"pool": str,
"pool_slots": int,
"trigger_rule": str,
"run_as_user": str,
"task_concurrency": int,
"map_index_template": str,
"max_active_tis_per_dag": int,
"max_active_tis_per_dagrun": int,
"executor": str,
"do_xcom_push": bool,
"multiple_outputs": bool,
"doc": str,
"doc_md": str,
"doc_json": str,
"doc_yaml": str,
"doc_rst": str,
"task_display_name": str,
"logger_name": str,
"allow_nested_operators": bool,
}


@total_ordering
class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
r"""
Expand Down Expand Up @@ -1078,6 +1119,8 @@ def __init__(
if SetupTeardownContext.active:
SetupTeardownContext.update_context_map(self)

validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)

def __eq__(self, other):
if type(self) is type(other):
# Use getattr() instead of __dict__ as __dict__ doesn't return
Expand Down
30 changes: 29 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.dates import cron_presets, date_range as utils_date_range
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.helpers import at_most_one, exactly_one, validate_key
from airflow.utils.helpers import at_most_one, exactly_one, validate_instance_args, validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import (
Expand Down Expand Up @@ -341,6 +341,32 @@ def _create_orm_dagrun(
return run


# TODO: The following mapping is used to validate that the arguments passed to the DAG are of the correct
# type. This is a temporary solution until we find a more sophisticated method for argument validation.
# One potential method is to use `get_type_hints` from the typing module. However, this is not fully
# compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
# version that supports `get_type_hints` effectively or find a better approach, we can replace this
# manual type-checking method.
DAG_ARGS_EXPECTED_TYPES = {
"dag_id": str,
"description": str,
"max_active_tasks": int,
"max_active_runs": int,
"max_consecutive_failed_dag_runs": int,
"dagrun_timeout": timedelta,
"default_view": str,
"orientation": str,
"catchup": bool,
"doc_md": str,
"is_paused_upon_creation": bool,
"render_template_as_native_obj": bool,
"tags": list,
"auto_register": bool,
"fail_stop": bool,
"dag_display_name": str,
}


@functools.total_ordering
class DAG(LoggingMixin):
"""
Expand Down Expand Up @@ -744,6 +770,8 @@ def __init__(
# fileloc based only on the serialize dag
self._processor_dags_folder = None

validate_instance_args(self, DAG_ARGS_EXPECTED_TYPES)

def get_doc_md(self, doc_md: str | None) -> str | None:
if doc_md is None:
return doc_md
Expand Down
11 changes: 11 additions & 0 deletions airflow/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ def validate_key(k: str, max_length: int = 250):
)


def validate_instance_args(instance: object, expected_arg_types: dict[str, Any]) -> None:
"""Validate that the instance has the expected types for the arguments."""
for arg_name, expected_arg_type in expected_arg_types.items():
instance_arg_value = getattr(instance, arg_name, None)
if instance_arg_value is not None and not isinstance(instance_arg_value, expected_arg_type):
raise TypeError(
f"'{arg_name}' has an invalid type {type(instance_arg_value)} with value "
f"{instance_arg_value}, expected type is {expected_arg_type}"
)


def validate_group_key(k: str, max_length: int = 200):
"""Validate value used as a group key."""
if not isinstance(k, str):
Expand Down
19 changes: 18 additions & 1 deletion airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from airflow.models.taskmixin import DAGNode
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key
from airflow.utils.helpers import validate_group_key, validate_instance_args

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand All @@ -49,6 +49,21 @@
from airflow.models.taskmixin import DependencyMixin
from airflow.utils.edgemodifier import EdgeModifier

# TODO: The following mapping is used to validate that the arguments passed to the TaskGroup are of the
# correct type. This is a temporary solution until we find a more sophisticated method for argument
# validation. One potential method is to use get_type_hints from the typing module. However, this is not
# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
# version that supports `get_type_hints` effectively or find a better approach, we can replace this
# manual type-checking method.
TASKGROUP_ARGS_EXPECTED_TYPES = {
"group_id": str,
"prefix_group_id": bool,
"tooltip": str,
"ui_color": str,
"ui_fgcolor": str,
"add_suffix_on_collision": bool,
}


class TaskGroup(DAGNode):
"""
Expand Down Expand Up @@ -160,6 +175,8 @@ def __init__(
self.upstream_task_ids = set()
self.downstream_task_ids = set()

validate_instance_args(self, TASKGROUP_ARGS_EXPECTED_TYPES)

def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
if self._group_id is None:
return
Expand Down
17 changes: 17 additions & 0 deletions tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule, RemovedInAirflow3Warning
from airflow.lineage.entities import File
from airflow.models.baseoperator import (
BASEOPERATOR_ARGS_EXPECTED_TYPES,
BaseOperator,
BaseOperatorMeta,
chain,
Expand Down Expand Up @@ -811,6 +812,22 @@ def test_logging_propogated_by_default(self, caplog):
# leaking a lot of state)
assert caplog.messages == ["test"]

def test_invalid_type_for_default_arg(self):
error_msg = "'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int, expected type is <class 'int'>"
with pytest.raises(TypeError, match=error_msg):
BaseOperator(task_id="test", default_args={"max_active_tis_per_dag": "not_an_int"})

def test_invalid_type_for_operator_arg(self):
error_msg = "'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int, expected type is <class 'int'>"
with pytest.raises(TypeError, match=error_msg):
BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int")

@mock.patch("airflow.models.baseoperator.validate_instance_args")
def test_baseoperator_init_validates_arg_types(self, mock_validate_instance_args):
operator = BaseOperator(task_id="test")

mock_validate_instance_args.assert_called_once_with(operator, BASEOPERATOR_ARGS_EXPECTED_TYPES)


def test_init_subclass_args():
class InitSubclassOp(BaseOperator):
Expand Down
13 changes: 13 additions & 0 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import (
DAG,
DAG_ARGS_EXPECTED_TYPES,
DagModel,
DagOwnerAttributes,
DagTag,
Expand Down Expand Up @@ -3928,6 +3929,18 @@ def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagR
)


def test_invalid_type_for_args():
with pytest.raises(TypeError):
DAG("invalid-default-args", max_consecutive_failed_dag_runs="not_an_int")


@mock.patch("airflow.models.dag.validate_instance_args")
def test_dag_init_validates_arg_types(mock_validate_instance_args):
dag = DAG("dag_with_expected_args")

mock_validate_instance_args.assert_called_once_with(dag, DAG_ARGS_EXPECTED_TYPES)


class TestTaskClearingSetupTeardownBehavior:
"""
Task clearing behavior is mainly controlled by dag.partial_subset.
Expand Down
34 changes: 34 additions & 0 deletions tests/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
merge_dicts,
prune_dict,
validate_group_key,
validate_instance_args,
validate_key,
)
from airflow.utils.types import NOTSET
Expand Down Expand Up @@ -355,3 +356,36 @@ class SchedulerJobRunner(MockJobRunner):

class TriggererJobRunner(MockJobRunner):
job_type = "TriggererJob"


class ClassToValidateArgs:
def __init__(self, name, age, active):
self.name = name
self.age = age
self.active = active


# Edge cases
@pytest.mark.parametrize(
"instance, expected_arg_types",
[
(ClassToValidateArgs("Alice", 30, None), {"name": str, "age": int, "active": bool}),
(ClassToValidateArgs(None, 25, True), {"name": str, "age": int, "active": bool}),
],
)
def test_validate_instance_args_raises_no_error(instance, expected_arg_types):
validate_instance_args(instance, expected_arg_types)


# Error cases
@pytest.mark.parametrize(
"instance, expected_arg_types",
[
(ClassToValidateArgs("Alice", "thirty", True), {"name": str, "age": int, "active": bool}),
(ClassToValidateArgs("Bob", 25, "yes"), {"name": str, "age": int, "active": bool}),
(ClassToValidateArgs(123, 25, True), {"name": str, "age": int, "active": bool}),
],
)
def test_validate_instance_args_raises_error(instance, expected_arg_types):
with pytest.raises(TypeError):
validate_instance_args(instance, expected_arg_types)
20 changes: 19 additions & 1 deletion tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from datetime import timedelta
from unittest import mock

import pendulum
import pytest
Expand All @@ -37,7 +38,7 @@
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.utils.dag_edges import dag_edges
from airflow.utils.task_group import TaskGroup, task_group_to_dict
from airflow.utils.task_group import TASKGROUP_ARGS_EXPECTED_TYPES, TaskGroup, task_group_to_dict
from tests.models import DEFAULT_DATE


Expand Down Expand Up @@ -1630,3 +1631,20 @@ def work(): ...
assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"}
assert set(t1.operator.downstream_task_ids) == set()
assert set(t2.operator.downstream_task_ids) == set()


def test_task_group_with_invalid_arg_type_raises_error():
error_msg = "'ui_color' has an invalid type <class 'int'> with value 123, expected type is <class 'str'>"
with DAG(dag_id="dag_with_tg_invalid_arg_type"):
with pytest.raises(TypeError, match=error_msg):
with TaskGroup("group_1", ui_color=123):
EmptyOperator(task_id="task1")


@mock.patch("airflow.utils.task_group.validate_instance_args")
def test_task_group_init_validates_arg_types(mock_validate_instance_args):
with DAG(dag_id="dag_with_tg_valid_arg_types"):
with TaskGroup("group_1", ui_color="red") as tg:
EmptyOperator(task_id="task1")

mock_validate_instance_args.assert_called_with(tg, TASKGROUP_ARGS_EXPECTED_TYPES)