Skip to content
7 changes: 5 additions & 2 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def _get_email_subject_content(

else:
from airflow.sdk.definitions._internal.templater import SandboxedEnvironment
from airflow.utils.context import context_merge
from airflow.sdk.definitions.context import Context

if TYPE_CHECKING:
assert task_instance.task
Expand All @@ -381,7 +381,10 @@ def _get_email_subject_content(
else:
jinja_env = SandboxedEnvironment(cache_size=0)
jinja_context = task_instance.get_template_context()
context_merge(jinja_context, additional_context)
if not jinja_context:
jinja_context = Context()
# Add additional fields to the context for email template rendering
jinja_context.update(additional_context) # type: ignore[typeddict-item]

def render(key: str, content: str) -> str:
if conf.has_option("email", key):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from airflow.providers.cncf.kubernetes.version_compat import (
DecoratedOperator,
TaskDecorator,
context_merge,
task_decorator_factory,
)
from airflow.utils.context import context_merge
from airflow.utils.operator_helpers import determine_kwargs

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
if AIRFLOW_V_3_1_PLUS:
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.sdk import BaseHook, BaseOperator
from airflow.sdk.definitions.context import context_merge
else:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef]
from airflow.models import BaseOperator
from airflow.utils.context import context_merge # type: ignore[attr-defined, no-redef]
from airflow.utils.xcom import XCOM_RETURN_KEY # type: ignore[no-redef]

if AIRFLOW_V_3_0_PLUS:
Expand All @@ -64,4 +66,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
"TaskDecorator",
"task_decorator_factory",
"XCOM_RETURN_KEY",
"context_merge",
]
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
)

from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.version_compat import context_merge
from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
from airflow.utils.context import context_merge
from airflow.utils.operator_helpers import determine_kwargs

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@
from airflow.models.variable import Variable
from airflow.providers.standard.hooks.package_index import PackageIndexHook
from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator, context_merge
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
Expand Down Expand Up @@ -487,7 +486,8 @@ def _iter_serializable_context_keys(self):

def execute(self, context: Context) -> Any:
serializable_keys = set(self._iter_serializable_context_keys())
serializable_context = context_copy_partial(context, serializable_keys)
new = {k: v for k, v in context.items() if k in serializable_keys}
serializable_context = cast("Context", new)
return super().execute(context=serializable_context)

def get_python_source(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from airflow.providers.standard.version_compat import BaseSensorOperator, PokeReturnValue
from airflow.utils.context import context_merge
from airflow.providers.standard.version_compat import BaseSensorOperator, PokeReturnValue, context_merge
from airflow.utils.operator_helpers import determine_kwargs

if TYPE_CHECKING:
try:
from airflow.sdk.definitions.context import Context
except ImportError:
# TODO: Remove once provider drops support for Airflow 2
from airflow.utils.context import Context
from airflow.utils.context import Context # type: ignore[no-redef, attr-defined]


class PythonSensor(BaseSensorOperator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
# even though it wasn't used.
if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import BaseHook, BaseOperator
from airflow.sdk.definitions.context import context_merge
else:
from airflow.hooks.base import BaseHook # type: ignore[attr-defined,no-redef]
from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
from airflow.utils.context import context_merge # type: ignore[no-redef, attr-defined]

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperatorLink
Expand All @@ -59,4 +61,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
"BaseHook",
"BaseSensorOperator",
"PokeReturnValue",
"context_merge",
]
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/bases/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@
)
from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.context import KNOWN_CONTEXT_KEYS
from airflow.sdk.definitions.mappedoperator import (
MappedOperator,
ensure_xcomarg_return_value,
prevent_duplicates,
)
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.utils.context import KNOWN_CONTEXT_KEYS
from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/bases/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import TYPE_CHECKING

from airflow.sdk.definitions._internal.templater import Templater
from airflow.utils.context import context_merge
from airflow.sdk.definitions.context import context_merge
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
Expand Down
22 changes: 22 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ class Context(TypedDict, total=False):
var: Any


KNOWN_CONTEXT_KEYS: set[str] = set(Context.__annotations__.keys())


def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
"""
Merge parameters into an existing context.

Like ``dict.update()`` , this take the same parameters, and updates
``context`` in-place.

This is implemented as a free function because the ``Context`` type is
"faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
functions.

:meta private:
"""
if not context:
context = Context()

context.update(*args, **kwargs)


def get_current_context() -> Context:
"""
Retrieve the execution context dictionary without altering user method's signature.
Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.bases.operatorlink import BaseOperatorLink
from airflow.sdk.definitions._internal.expandinput import ExpandInput
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.trigger_rule import TriggerRule

Expand Down
Loading