Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,11 +1035,13 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
continue
serialized_op["partial_kwargs"].update({k: cls.serialize(v)})

# we want to store python_callable_name, not python_callable
# Store python_callable_name instead of python_callable.
# exclude_module=True ensures stable names across bundle version changes.
python_callable = op.partial_kwargs.get("python_callable", None)
if python_callable:
callable_name = qualname(python_callable)
serialized_op["partial_kwargs"]["python_callable_name"] = callable_name
serialized_op["partial_kwargs"]["python_callable_name"] = qualname(
python_callable, exclude_module=True
)
del serialized_op["partial_kwargs"]["python_callable"]

serialized_op["_is_mapped"] = True
Expand All @@ -1060,11 +1062,11 @@ def _serialize_node(cls, op: SdkOperator) -> dict[str, Any]:
if attr in serialize_op:
del serialize_op[attr]

# Detect if there's a change in python callable name
# Store python_callable_name for change detection.
# exclude_module=True ensures stable names across bundle version changes.
python_callable = getattr(op, "python_callable", None)
if python_callable:
callable_name = qualname(python_callable)
serialize_op["python_callable_name"] = callable_name
serialize_op["python_callable_name"] = qualname(python_callable, exclude_module=True)

serialize_op["task_type"] = getattr(op, "task_type", type(op).__name__)
serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__)
Expand Down
40 changes: 36 additions & 4 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import contextlib
import copy
import dataclasses
import functools
import importlib
import importlib.util
import json
Expand Down Expand Up @@ -2934,7 +2935,7 @@ def x(arg1, arg2, arg3):
},
"_disallow_kwargs_override": False,
"_expand_input_attr": "op_kwargs_expand_input",
"python_callable_name": qualname(x),
"python_callable_name": "test_taskflow_expand_serde.<locals>.x",
}

deserialized = BaseSerialization.deserialize(serialized)
Expand Down Expand Up @@ -3001,7 +3002,7 @@ def x(arg1, arg2, arg3):
"_task_module": "airflow.providers.standard.decorators.python",
"task_type": "_PythonDecoratedOperator",
"_operator_name": "@task",
"python_callable_name": qualname(x),
"python_callable_name": "test_taskflow_expand_kwargs_serde.<locals>.x",
"partial_kwargs": {
"op_args": [],
"op_kwargs": {
Expand Down Expand Up @@ -3172,11 +3173,42 @@ def test_python_callable_in_partial_kwargs():

serialized = OperatorSerialization.serialize_mapped_operator(operator)
assert "python_callable" not in serialized["partial_kwargs"]
assert serialized["partial_kwargs"]["python_callable_name"] == qualname(empty_function)
assert serialized["partial_kwargs"]["python_callable_name"] == "empty_function"

deserialized = OperatorSerialization.deserialize_operator(serialized)
assert "python_callable" not in deserialized.partial_kwargs
assert deserialized.partial_kwargs["python_callable_name"] == qualname(empty_function)
assert deserialized.partial_kwargs["python_callable_name"] == "empty_function"


def test_python_callable_name_uses_qualname_exclude_module():
"""Test python_callable_name is stable across bundle version changes."""
from airflow.providers.standard.operators.python import PythonOperator

# Module-level function
op1 = PythonOperator(task_id="task1", python_callable=empty_function)
serialized1 = OperatorSerialization.serialize_operator(op1)
assert serialized1["python_callable_name"] == "empty_function"

# Nested function
def outer():
def inner():
pass

return inner

inner_func = outer()
op2 = PythonOperator(task_id="task2", python_callable=inner_func)
serialized2 = OperatorSerialization.serialize_operator(op2)
assert (
serialized2["python_callable_name"]
== "test_python_callable_name_uses_qualname_exclude_module.<locals>.outer.<locals>.inner"
)

# functools.partial
partial_func = functools.partial(empty_function, x=1)
op3 = PythonOperator(task_id="task3", python_callable=partial_func)
serialized3 = OperatorSerialization.serialize_operator(op3)
assert serialized3["python_callable_name"] == "empty_function"


def test_handle_v1_serdag():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,24 @@ def import_string(dotted_path: str):
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class')


def qualname(o: object | Callable, use_qualname: bool = False) -> str:
"""Convert an attribute/class/callable to a string importable by ``import_string``."""
def qualname(o: object | Callable, use_qualname: bool = False, exclude_module: bool = False) -> str:
"""
Convert an attribute/class/callable to a string.

By default, returns a string importable by ``import_string`` (includes module path).
With exclude_module=True, returns only the qualified name without module prefix,
useful for stable identification across deployments where module paths may vary.
"""
if callable(o) and hasattr(o, "__module__"):
if exclude_module:
if hasattr(o, "__qualname__"):
return o.__qualname__
if hasattr(o, "__name__"):
return o.__name__
# Handle functools.partial objects specifically (not just any object with 'func' attr)
if isinstance(o, functools.partial):
return qualname(o.func, exclude_module=True)
return type(o).__qualname__
if use_qualname and hasattr(o, "__qualname__"):
return f"{o.__module__}.{o.__qualname__}"
if hasattr(o, "__name__"):
Expand All @@ -79,6 +94,9 @@ def qualname(o: object | Callable, use_qualname: bool = False) -> str:
name = cls.__qualname__
module = cls.__module__

if exclude_module:
return name

if module and module != "__builtin__":
return f"{module}.{name}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@
# under the License.
from __future__ import annotations

import functools

import pytest

from airflow_shared.module_loading import import_string, is_valid_dotpath
from airflow_shared.module_loading import import_string, is_valid_dotpath, qualname


def _import_string():
pass


def _sample_function():
pass


class TestModuleImport:
def test_import_string(self):
cls = import_string("module_loading.test_module_loading._import_string")
Expand Down Expand Up @@ -56,3 +62,65 @@ class TestModuleLoading:
)
def test_is_valid_dotpath(self, path, expected):
assert is_valid_dotpath(path) == expected


class TestQualname:
def test_qualname_default_includes_module(self):
"""Test that qualname() by default includes the module path."""
result = qualname(_sample_function)
assert result == "module_loading.test_module_loading._sample_function"

def test_qualname_exclude_module_simple_function(self):
"""Test that exclude_module=True returns only the function name."""
result = qualname(_sample_function, exclude_module=True)
assert result == "_sample_function"

def test_qualname_exclude_module_nested_function(self):
"""Test that exclude_module=True works with nested functions."""

def outer():
def inner():
pass

return inner

inner_func = outer()
result = qualname(inner_func, exclude_module=True)
assert (
result
== "TestQualname.test_qualname_exclude_module_nested_function.<locals>.outer.<locals>.inner"
)

def test_qualname_exclude_module_functools_partial(self):
"""Test that exclude_module=True handles functools.partial correctly."""

def base_func(x, y):
pass

partial_func = functools.partial(base_func, x=1)
result = qualname(partial_func, exclude_module=True)
assert result == "TestQualname.test_qualname_exclude_module_functools_partial.<locals>.base_func"

def test_qualname_exclude_module_class(self):
"""Test that exclude_module=True works with classes."""

class MyClass:
pass

result = qualname(MyClass, exclude_module=True)
assert result == "TestQualname.test_qualname_exclude_module_class.<locals>.MyClass"

def test_qualname_exclude_module_instance(self):
"""Test that exclude_module=True works with class instances."""

class MyClass:
pass

instance = MyClass()
result = qualname(instance, exclude_module=True)
assert result == "TestQualname.test_qualname_exclude_module_instance.<locals>.MyClass"

def test_qualname_use_qualname_still_includes_module(self):
"""Test that use_qualname=True still includes module prefix."""
result = qualname(_sample_function, use_qualname=True)
assert result == "module_loading.test_module_loading._sample_function"