Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
from collections.abc import MutableMapping
from functools import cache
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from airflow.api_fastapi.core_api.datamodels.connections import (
ConnectionHookFieldBehavior,
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(
description: str = "",
default: str | None = None,
widget=None,
source: Literal["dag", "task"] | None = None,
):
type: str | list[str] = [self.param_type, "null"]
enum = {}
Expand All @@ -82,6 +83,7 @@ def __init__(
default=default,
title=label,
description=description or None,
source=source or None,
type=type,
**format,
**enum,
Expand Down
12 changes: 10 additions & 2 deletions airflow-core/src/airflow/serialization/definitions/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import collections.abc
import copy
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from airflow.serialization.definitions.notset import NOTSET, is_arg_set

Expand All @@ -31,11 +31,18 @@
class SerializedParam:
"""Server-side param class for deserialization."""

def __init__(self, default: Any = NOTSET, description: str | None = None, **schema):
def __init__(
self,
default: Any = NOTSET,
description: str | None = None,
source: Literal["dag", "task"] | None = None,
**schema,
):
# No validation needed - the SDK already validated the default.
self.value = default
self.description = description
self.schema = schema
self.source = source

def resolve(self, *, raises: bool = False) -> Any:
"""
Expand Down Expand Up @@ -66,6 +73,7 @@ def dump(self) -> dict[str, Any]:
"value": self.resolve(),
"schema": self.schema,
"description": self.description,
"source": self.source,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,7 @@ def _serialize_param(cls, param: Param):
"default": cls.serialize(param.value),
"description": cls.serialize(param.description),
"schema": cls.serialize(param.schema),
"source": cls.serialize(getattr(param, "source", None)),
}

@classmethod
Expand All @@ -1048,7 +1049,7 @@ def _deserialize_param(cls, param_dict: dict) -> SerializedParam:
this class's ``serialize`` method. So before running through ``deserialize``,
we first verify that it's necessary to do.
"""
attrs = ("default", "description", "schema")
attrs = ("default", "description", "schema", "source")
kwargs = {}

def is_serialized(val):
Expand All @@ -1068,6 +1069,7 @@ def is_serialized(val):
return SerializedParam(
default=kwargs.get("default"),
description=kwargs.get("description"),
source=kwargs.get("source", None),
**(kwargs.get("schema") or {}),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,14 @@ def test_dag_details(
"next_dagrun_run_after": None,
"owners": ["airflow"],
"owner_links": {},
"params": {"foo": {"value": 1, "schema": {}, "description": None}},
"params": {
"foo": {
"value": 1,
"schema": {},
"description": None,
"source": None,
}
},
"relative_fileloc": "test_dags.py",
"render_template_as_native_obj": False,
"timetable_summary": None,
Expand Down Expand Up @@ -1034,7 +1041,14 @@ def test_dag_details_with_view_url_template(
"next_dagrun_run_after": None,
"owners": ["airflow"],
"owner_links": {},
"params": {"foo": {"value": 1, "schema": {}, "description": None}},
"params": {
"foo": {
"value": 1,
"schema": {},
"description": None,
"source": None,
}
},
"relative_fileloc": "test_dags.py",
"render_template_as_native_obj": False,
"timetable_summary": None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_should_respond_200(self, test_client):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
"params": {"foo": {"value": "bar", "schema": {}, "description": None}},
"params": {"foo": {"value": "bar", "schema": {}, "description": None, "source": "task"}},
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
Expand Down Expand Up @@ -180,7 +180,14 @@ def test_unscheduled_task(self, test_client):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
"params": {"is_unscheduled": {"value": True, "schema": {}, "description": None}},
"params": {
"is_unscheduled": {
"value": True,
"schema": {},
"description": None,
"source": "task",
}
},
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
Expand Down Expand Up @@ -239,7 +246,14 @@ def test_should_respond_200_serialized(self, test_client, testing_dag_bundle):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
"params": {"foo": {"value": "bar", "schema": {}, "description": None}},
"params": {
"foo": {
"value": "bar",
"schema": {},
"description": None,
"source": "task",
}
},
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
Expand Down Expand Up @@ -304,7 +318,14 @@ def test_should_respond_200(self, test_client):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
"params": {"foo": {"value": "bar", "schema": {}, "description": None}},
"params": {
"foo": {
"value": "bar",
"schema": {},
"description": None,
"source": "task",
}
},
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
Expand Down Expand Up @@ -459,7 +480,14 @@ def test_get_unscheduled_tasks(self, test_client):
"extra_links": [],
"operator_name": "EmptyOperator",
"owner": "airflow",
"params": {"is_unscheduled": {"value": True, "schema": {}, "description": None}},
"params": {
"is_unscheduled": {
"value": True,
"schema": {},
"description": None,
"source": "task",
}
},
"pool": "default_pool",
"pool_slots": 1.0,
"priority_weight": 1.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,7 @@ def test_full_param_roundtrip(self, param: Param):
"value": None if param.value is NOTSET else param.value,
"schema": param.schema,
"description": param.description,
"source": None,
}

@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ def __init__(
self.multiple = multiple

self.params: ParamsDict = params if isinstance(params, ParamsDict) else ParamsDict(params or {})
if hasattr(ParamsDict, "filter_params_by_source"):
# Params that exist only in Dag level does not make sense to appear in HITLOperator
self.params = ParamsDict.filter_params_by_source(self.params, source="task")
elif self.params:
self.log.debug(
"ParamsDict.filter_params_by_source not available; HITLOperator will also include Dag level params."
)

self.notifiers: Sequence[BaseNotifier] = (
[notifiers] if isinstance(notifiers, BaseNotifier) else notifiers or []
Expand Down
36 changes: 19 additions & 17 deletions providers/standard/tests/unit/standard/operators/test_hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytest

from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS

if not AIRFLOW_V_3_1_PLUS:
pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0", allow_module_level=True)
Expand Down Expand Up @@ -240,19 +240,23 @@ def test_execute(self, dag_maker: DagMaker, session: Session) -> None:
assert hitl_detail_model.responded_by is None
assert hitl_detail_model.chosen_options is None
assert hitl_detail_model.params_input == {}
if AIRFLOW_V_3_1_3_PLUS:
assert hitl_detail_model.params == {
"input_1": {
"value": 1,
"description": None,
"schema": {},
}
}
expected_params: dict[str, Any]
if AIRFLOW_V_3_2_PLUS:
expected_params = {"input_1": {"value": 1, "description": None, "schema": {}, "source": "task"}}
elif AIRFLOW_V_3_1_3_PLUS:
expected_params = {"input_1": {"value": 1, "description": None, "schema": {}}}
else:
assert hitl_detail_model.params == {"input_1": 1}
expected_params = {"input_1": 1}
assert hitl_detail_model.params == expected_params

assert notifier.called is True

expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
if AIRFLOW_V_3_2_PLUS:
expected_params_in_trigger_kwargs = expected_params
else:
expected_params_in_trigger_kwargs = {"input_1": {"value": 1, "description": None, "schema": {}}}

registered_trigger = session.scalar(
select(Trigger).where(Trigger.classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger")
)
Expand All @@ -261,13 +265,7 @@ def test_execute(self, dag_maker: DagMaker, session: Session) -> None:
"ti_id": ti.id,
"options": ["1", "2", "3", "4", "5"],
"defaults": ["1"],
"params": {
"input_1": {
"value": 1,
"description": None,
"schema": {},
}
},
"params": expected_params_in_trigger_kwargs,
"multiple": False,
"timeout_datetime": None,
"poke_interval": 5.0,
Expand Down Expand Up @@ -323,6 +321,10 @@ def test_serialzed_params(
options=["1", "2", "3", "4", "5"],
params=input_params,
)
if AIRFLOW_V_3_2_PLUS:
for key in expected_params:
expected_params[key]["source"] = "task"

assert hitl_op.serialized_params == expected_params

@pytest.mark.skipif(
Expand Down
20 changes: 17 additions & 3 deletions providers/standard/tests/unit/standard/triggers/test_hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytest

from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS

if not AIRFLOW_V_3_1_PLUS:
pytest.skip("Human in the loop public API compatible with Airflow >= 3.1.0", allow_module_level=True)
Expand Down Expand Up @@ -50,7 +50,12 @@ def default_trigger_args() -> dict[str, Any]:
"ti_id": TI_ID,
"options": ["1", "2", "3", "4", "5"],
"params": {
"input": {"value": 1, "schema": {}, "description": None},
"input": {
"value": 1,
"schema": {},
"description": None,
"source": "task",
},
},
"multiple": False,
}
Expand All @@ -65,11 +70,20 @@ def test_serialization(self, default_trigger_args):
**default_trigger_args,
)
classpath, kwargs = trigger.serialize()

expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
if AIRFLOW_V_3_2_PLUS:
expected_params_in_trigger_kwargs = {
"input": {"value": 1, "description": None, "schema": {}, "source": "task"}
}
else:
expected_params_in_trigger_kwargs = {"input": {"value": 1, "description": None, "schema": {}}}

assert classpath == "airflow.providers.standard.triggers.hitl.HITLTrigger"
assert kwargs == {
"ti_id": TI_ID,
"options": ["1", "2", "3", "4", "5"],
"params": {"input": {"value": 1, "description": None, "schema": {}}},
"params": expected_params_in_trigger_kwargs,
"defaults": ["1"],
"multiple": False,
"timeout_datetime": None,
Expand Down
10 changes: 9 additions & 1 deletion task-sdk/src/airflow/sdk/bases/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple
return {}, ParamsDict()
dag_args = copy.copy(dag.default_args)
dag_params = copy.deepcopy(dag.params)
dag_params._fill_missing_param_source("dag")
if task_group:
if task_group.default_args and not isinstance(task_group.default_args, collections.abc.Mapping):
raise TypeError("default_args must be a mapping")
Expand All @@ -155,13 +156,20 @@ def get_merged_defaults(
if task_params:
if not isinstance(task_params, collections.abc.Mapping):
raise TypeError(f"params must be a mapping, got {type(task_params)}")

task_params = ParamsDict(task_params)
task_params._fill_missing_param_source("task")
params.update(task_params)

if task_default_args:
if not isinstance(task_default_args, collections.abc.Mapping):
raise TypeError(f"default_args must be a mapping, got {type(task_params)}")
args.update(task_default_args)
with contextlib.suppress(KeyError):
params.update(task_default_args["params"] or {})
if params_from_default_args := ParamsDict(task_default_args["params"] or {}):
params_from_default_args._fill_missing_param_source("task")
params.update(params_from_default_args)

return args, params


Expand Down
Loading
Loading