Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values go to op_kwargs_expand_input.
partial_kwargs=partial_kwargs,
task_id=task_id,
map_index_template=partial_kwargs.pop("map_index_template", None),
params=partial_params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def partial(**kwargs):
return self.class_method.__get__(cls, cls)


_PARTIAL_DEFAULTS = {
_PARTIAL_DEFAULTS: dict[str, Any] = {
"map_index_template": None,
"owner": DEFAULT_OWNER,
"trigger_rule": DEFAULT_TRIGGER_RULE,
"depends_on_past": False,
Expand Down
53 changes: 31 additions & 22 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

if TYPE_CHECKING:
import datetime
from typing import List

import jinja2 # Slow import.
import pendulum
Expand All @@ -83,6 +84,8 @@
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule

TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, List[TaskStateChangeCallback]]

ValidationSource = Union[Literal["expand"], Literal["partial"]]


Expand Down Expand Up @@ -211,7 +214,6 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
expand_input=expand_input,
partial_kwargs=partial_kwargs,
task_id=task_id,
map_index_template=partial_kwargs.pop("map_index_template", None),
params=self.params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
Expand Down Expand Up @@ -281,7 +283,6 @@ class MappedOperator(AbstractOperator):
end_date: pendulum.DateTime | None
upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
map_index_template: str | None

_disallow_kwargs_override: bool
"""Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.
Expand Down Expand Up @@ -392,6 +393,14 @@ def owner(self) -> str: # type: ignore[override]
def email(self) -> None | str | Iterable[str]:
return self.partial_kwargs.get("email")

@property
def map_index_template(self) -> None | str:
return self.partial_kwargs.get("map_index_template")

@map_index_template.setter
def map_index_template(self, value: str | None) -> None:
self.partial_kwargs["map_index_template"] = value

@property
def trigger_rule(self) -> TriggerRule:
return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
Expand Down Expand Up @@ -453,35 +462,35 @@ def wait_for_downstream(self, value: bool) -> None:
self.partial_kwargs["wait_for_downstream"] = value

@property
def retries(self) -> int | None:
def retries(self) -> int:
return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

@retries.setter
def retries(self, value: int | None) -> None:
def retries(self, value: int) -> None:
self.partial_kwargs["retries"] = value

@property
def queue(self) -> str:
return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

@queue.setter
def queue(self, value: str | None) -> None:
def queue(self, value: str) -> None:
self.partial_kwargs["queue"] = value

@property
def pool(self) -> str:
return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)

@pool.setter
def pool(self, value: str | None) -> None:
def pool(self, value: str) -> None:
self.partial_kwargs["pool"] = value

@property
def pool_slots(self) -> str | None:
def pool_slots(self) -> int:
return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

@pool_slots.setter
def pool_slots(self, value: str | None) -> None:
def pool_slots(self, value: int) -> None:
self.partial_kwargs["pool_slots"] = value

@property
Expand All @@ -505,31 +514,31 @@ def retry_delay(self) -> datetime.timedelta:
return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

@retry_delay.setter
def retry_delay(self, value: datetime.timedelta | None) -> None:
def retry_delay(self, value: datetime.timedelta) -> None:
self.partial_kwargs["retry_delay"] = value

@property
def retry_exponential_backoff(self) -> bool:
return bool(self.partial_kwargs.get("retry_exponential_backoff"))

@retry_exponential_backoff.setter
def retry_exponential_backoff(self, value: bool | None) -> None:
def retry_exponential_backoff(self, value: bool) -> None:
self.partial_kwargs["retry_exponential_backoff"] = value

@property
def priority_weight(self) -> int: # type: ignore[override]
return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

@priority_weight.setter
def priority_weight(self, value: int | None) -> None:
def priority_weight(self, value: int) -> None:
self.partial_kwargs["priority_weight"] = value

@property
def weight_rule(self) -> str: # type: ignore[override]
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)

@weight_rule.setter
def weight_rule(self, value: str | None) -> None:
def weight_rule(self, value: str) -> None:
self.partial_kwargs["weight_rule"] = value

@property
Expand Down Expand Up @@ -561,43 +570,43 @@ def resources(self) -> Resources | None:
return self.partial_kwargs.get("resources")

@property
def on_execute_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_execute_callback")

@on_execute_callback.setter
def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_execute_callback"] = value

@property
def on_failure_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_failure_callback")

@on_failure_callback.setter
def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_failure_callback"] = value

@property
def on_retry_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_retry_callback")

@on_retry_callback.setter
def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_retry_callback"] = value

@property
def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_success_callback")

@on_success_callback.setter
def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_success_callback"] = value

@property
def on_skipped_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_skipped_callback")

@on_skipped_callback.setter
def on_skipped_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_skipped_callback"] = value

@property
Expand Down
1 change: 0 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,6 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
task_group=None,
start_date=None,
end_date=None,
map_index_template=None,
disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
expand_input_attr=encoded_op["_expand_input_attr"],
)
Expand Down