52 changes: 30 additions & 22 deletions pyproject.toml
@@ -69,20 +69,20 @@ asyncio_default_fixture_loop_scope = "function"

# Enable docstring linting using the google style guide
[tool.ruff.lint]
select = ["ALL" ]
select = ["ALL"]
ignore = [
"A001", # Allow using words like min as variable names
"A002", # Allow using words like filter as variable names
"ANN401", # Allow Any for wrapper classes
"COM812", # Recommended to ignore these rules when using with ruff-format
"FIX002", # Allow TODO lines - consider removing at some point
"FBT001", # Allow boolean positional args
"FBT002", # Allow boolean positional args
"ISC001", # Recommended to ignore these rules when using with ruff-format
"SLF001", # Allow accessing private members
"A001", # Allow using words like min as variable names
"A002", # Allow using words like filter as variable names
"ANN401", # Allow Any for wrapper classes
"COM812", # Recommended to ignore these rules when using with ruff-format
"FIX002", # Allow TODO lines - consider removing at some point
"FBT001", # Allow boolean positional args
"FBT002", # Allow boolean positional args
"ISC001", # Recommended to ignore these rules when using with ruff-format
"SLF001", # Allow accessing private members
"TD002",
"TD003", # Allow TODO lines
"UP007", # Disallowing Union is pedantic
"TD003", # Allow TODO lines
"UP007", # Disallowing Union is pedantic
# TODO: Enable all of the following, but this PR is getting too large already
"PLR0913",
"TRY003",
@@ -129,25 +129,33 @@ extend-allowed-calls = ["lit", "datafusion.lit"]
]
"examples/*" = ["D", "W505", "E501", "T201", "S101"]
"dev/*" = ["D", "E", "T", "S", "PLR", "C", "SIM", "UP", "EXE", "N817"]
"benchmarks/*" = ["D", "F", "T", "BLE", "FURB", "PLR", "E", "TD", "TRY", "S", "SIM", "EXE", "UP"]
"benchmarks/*" = [
"D",
"F",
"T",
"BLE",
"FURB",
"PLR",
"E",
"TD",
"TRY",
"S",
"SIM",
"EXE",
"UP",
]
"docs/*" = ["D"]

[tool.codespell]
skip = [
"./target",
"uv.lock",
"./python/tests/test_functions.py"
]
skip = ["./target", "uv.lock", "./python/tests/test_functions.py"]
count = true
ignore-words-list = [
"ans",
"IST"
]
ignore-words-list = ["ans", "IST"]

[dependency-groups]
dev = [
"maturin>=1.8.1",
"numpy>1.25.0",
"pyarrow>=19.0.0",
@kosiew kosiew (Contributor, Author) commented on Oct 23, 2025:

This is the change added to ensure pa.uuid() is available for test_udf.py.

pyarrow 19.0 is the lowest version that provides pyarrow.uuid (see https://arrow.apache.org/docs/19.0/python/generated/pyarrow.uuid.html).

The rest of the changes are VS Code automatic formatting.
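
For context, a minimal sketch (not part of this PR) of the pyarrow.uuid usage that the pyarrow>=19.0.0 pin enables; the variable names are illustrative:

```python
import uuid

import pyarrow as pa

# pa.uuid() first appears in pyarrow 19.0; its storage type is
# fixed_size_binary(16), so the 16 raw bytes of a UUID back it directly.
uuid_type = pa.uuid()
storage = pa.array([uuid.uuid4().bytes], type=uuid_type.storage_type)
uuid_array = uuid_type.wrap_array(storage)  # extension array carrying UUID metadata
```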

"pre-commit>=4.0.0",
"pytest>=7.4.4",
"pytest-asyncio>=0.23.3",
184 changes: 151 additions & 33 deletions python/datafusion/user_defined.py
@@ -22,15 +22,39 @@
import functools
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
from typing import (
Any,
Callable,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
cast,
overload,
)

import pyarrow as pa

import datafusion._internal as df_internal
from datafusion.expr import Expr

if TYPE_CHECKING:
_R = TypeVar("_R", bound=pa.DataType)
PyArrowArray = Union[pa.Array, pa.ChunkedArray]
# Type alias for array batches exchanged with Python scalar UDFs.
#
# We need two related but different annotations here:
# - `PyArrowArray` is the concrete union type (pa.Array | pa.ChunkedArray)
# that is convenient for user-facing callables and casts. Use this when
# annotating or checking values that may be either an Array or
# a ChunkedArray.
# - `PyArrowArrayT` is a constrained `TypeVar` over the two concrete
# array flavors. Keeping a generic TypeVar allows helpers like
# `_wrap_extension_value` and `_wrap_udf_function` to remain generic
# and preserve the specific array "flavor" (Array vs ChunkedArray)
# flowing through them, rather than collapsing everything to the
# wide union. This improves type-checking and keeps return types
# precise in the wrapper logic.
PyArrowArrayT = TypeVar("PyArrowArrayT", pa.Array, pa.ChunkedArray)
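
As a quick illustration (not part of the diff, and the helper name is hypothetical), a function annotated with `PyArrowArrayT` preserves the concrete array flavor it receives, whereas one annotated with the `PyArrowArray` union would widen the return type:

```python
def _passthrough(value: PyArrowArrayT) -> PyArrowArrayT:
    # Returns its input unchanged; the constrained TypeVar lets a type
    # checker infer ChunkedArray in -> ChunkedArray out (likewise for Array).
    return value

chunks = pa.chunked_array([pa.array([1, 2, 3])])
same_flavor = _passthrough(chunks)  # inferred as pa.ChunkedArray, not the union
```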


class Volatility(Enum):
@@ -77,6 +101,87 @@ def __str__(self) -> str:
return self.name.lower()


def _clone_field(field: pa.Field) -> pa.Field:
"""Return a deep copy of ``field`` including its DataType."""
return pa.schema([field]).field(0)


def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field:
if isinstance(value, pa.Field):
return _clone_field(value)
if isinstance(value, pa.DataType):
return _clone_field(pa.field(default_name, value))
msg = "Expected a pyarrow.DataType or pyarrow.Field"
raise TypeError(msg)


def _normalize_input_fields(
values: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
) -> list[pa.Field]:
if isinstance(values, (pa.DataType, pa.Field)):
sequence: Sequence[pa.DataType | pa.Field] = [values]
elif isinstance(values, Sequence) and not isinstance(values, (str, bytes)):
sequence = values
else:
msg = "input_types must be a DataType, Field, or a sequence of them"
raise TypeError(msg)

return [
_normalize_field(value, default_name=f"arg_{idx}")
for idx, value in enumerate(sequence)
]


def _normalize_return_field(
value: pa.DataType | pa.Field,
*,
name: str,
) -> pa.Field:
default_name = f"{name}_result" if name else "result"
return _normalize_field(value, default_name=default_name)


def _wrap_extension_value(
value: PyArrowArrayT, data_type: pa.DataType
) -> PyArrowArrayT:
storage_type = getattr(data_type, "storage_type", None)
wrap_array = getattr(data_type, "wrap_array", None)
if storage_type is None or wrap_array is None:
return value
if isinstance(value, pa.Array) and value.type.equals(storage_type):
return wrap_array(value)
if isinstance(value, pa.ChunkedArray) and value.type.equals(storage_type):
wrapped_chunks = [wrap_array(chunk) for chunk in value.chunks]
if not wrapped_chunks:
empty_storage = pa.array([], type=storage_type)
return wrap_array(empty_storage)
return pa.chunked_array(wrapped_chunks, type=data_type)
return value


def _wrap_udf_function(
func: Callable[..., PyArrowArrayT],
input_fields: Sequence[pa.Field],
return_field: pa.Field,
) -> Callable[..., PyArrowArrayT]:
def wrapper(*args: Any, **kwargs: Any) -> PyArrowArrayT:
if args:
converted_args: list[Any] = list(args)
for idx, field in enumerate(input_fields):
if idx >= len(converted_args):
break
converted_args[idx] = _wrap_extension_value(
cast(PyArrowArray, converted_args[idx]),
field.type,
)
else:
converted_args = []
result = func(*converted_args, **kwargs)
return _wrap_extension_value(result, return_field.type)

return wrapper


class ScalarUDFExportable(Protocol):
"""Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""

@@ -93,9 +198,9 @@ class ScalarUDF:
def __init__(
self,
name: str,
func: Callable[..., _R],
input_types: pa.DataType | list[pa.DataType],
return_type: _R,
func: Callable[..., PyArrowArray] | ScalarUDFExportable,
input_types: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
return_type: pa.DataType | pa.Field,
volatility: Volatility | str,
) -> None:
"""Instantiate a scalar user-defined function (UDF).
@@ -105,10 +210,11 @@ def __init__(
if hasattr(func, "__datafusion_scalar_udf__"):
self._udf = df_internal.ScalarUDF.from_pycapsule(func)
return
if isinstance(input_types, pa.DataType):
input_types = [input_types]
normalized_inputs = _normalize_input_fields(input_types)
normalized_return = _normalize_return_field(return_type, name=name)
wrapped_func = _wrap_udf_function(func, normalized_inputs, normalized_return)
self._udf = df_internal.ScalarUDF(
name, func, input_types, return_type, str(volatility)
name, wrapped_func, normalized_inputs, normalized_return, str(volatility)
)

def __repr__(self) -> str:
@@ -127,18 +233,18 @@ def __call__(self, *args: Expr) -> Expr:
@overload
@staticmethod
def udf(
input_types: list[pa.DataType],
return_type: _R,
input_types: list[pa.DataType | pa.Field],
return_type: pa.DataType | pa.Field,
volatility: Volatility | str,
name: Optional[str] = None,
) -> Callable[..., ScalarUDF]: ...
) -> Callable[[Callable[..., PyArrowArray]], Callable[..., Expr]]: ...

@overload
@staticmethod
def udf(
func: Callable[..., _R],
input_types: list[pa.DataType],
return_type: _R,
func: Callable[..., PyArrowArray],
input_types: list[pa.DataType | pa.Field],
return_type: pa.DataType | pa.Field,
volatility: Volatility | str,
name: Optional[str] = None,
) -> ScalarUDF: ...
@@ -164,10 +270,15 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
backed ScalarUDF within a PyCapsule, you can pass this parameter
and ignore the rest. They will be determined directly from the
underlying function. See the online documentation for more information.
input_types (list[pa.DataType]): The data types of the arguments
to ``func``. This list must be of the same length as the number of
arguments.
return_type (_R): The data type of the return value from the function.
The callable should accept and return :class:`pyarrow.Array` or
:class:`pyarrow.ChunkedArray` values.
input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
This list must be of the same length as the number of arguments. Pass
:class:`pyarrow.Field` instances when you need to declare extension
metadata for an argument.
return_type (pa.DataType | pa.Field): The return type of the function.
Supply a :class:`pyarrow.Field` when the result should expose
extension metadata to downstream consumers.
volatility (Volatility | str): See `Volatility` for allowed values.
name (Optional[str]): A descriptive name for the function.

@@ -179,8 +290,13 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417

def double_func(x):
return x * 2
double_udf = udf(double_func, [pa.int32()], pa.int32(),
"volatile", "double_it")
double_udf = udf(
double_func,
[pa.int32()],
pa.int32(),
"volatile",
"double_it",
)

Example: Using ``udf`` as a decorator::

Expand All @@ -190,9 +306,9 @@ def double_udf(x):
"""

def _function(
func: Callable[..., _R],
input_types: list[pa.DataType],
return_type: _R,
func: Callable[..., PyArrowArray],
input_types: list[pa.DataType | pa.Field],
return_type: pa.DataType | pa.Field,
volatility: Volatility | str,
name: Optional[str] = None,
) -> ScalarUDF:
@@ -213,18 +329,18 @@ def _function(
)

def _decorator(
input_types: list[pa.DataType],
return_type: _R,
input_types: list[pa.DataType | pa.Field],
return_type: pa.DataType | pa.Field,
volatility: Volatility | str,
name: Optional[str] = None,
) -> Callable:
def decorator(func: Callable):
) -> Callable[[Callable[..., PyArrowArray]], Callable[..., Expr]]:
def decorator(func: Callable[..., PyArrowArray]) -> Callable[..., Expr]:
udf_caller = ScalarUDF.udf(
func, input_types, return_type, volatility, name
)

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any):
def wrapper(*args: Any, **kwargs: Any) -> Expr:
return udf_caller(*args, **kwargs)

return wrapper
@@ -357,10 +473,12 @@ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
This class allows you to define an aggregate function that can be used in
data aggregation or window function calls.

Usage:
- As a function: ``udaf(accum, input_types, return_type, state_type, volatility, name)``.
- As a decorator: ``@udaf(input_types, return_type, state_type, volatility, name)``.
When using ``udaf`` as a decorator, do not pass ``accum`` explicitly.
Usage:
- As a function: ``udaf(accum, input_types, return_type, state_type,``
``volatility, name)``.
- As a decorator: ``@udaf(input_types, return_type, state_type,``
``volatility, name)``.
When using ``udaf`` as a decorator, do not pass ``accum`` explicitly.
Comment on lines +476 to +481

Reviewer (Member) commented:

It looks like the formatting got changed. Is this intentional?


Function example:
