
Commit 3ad0c4c

feat: add table-valued UDF support (#16)
* Add table-valued UDF support
* Format files
* Improve table UDF example
* Add table-function helper tests
Parent: aad19c1

6 files changed (+579 additions, −140 deletions)


python/README.md

Lines changed: 13 additions & 0 deletions
```diff
@@ -42,6 +42,19 @@ def wait_concurrent(x):
     time.sleep(2)
     return x
 
+# Define a table-valued function that emits multiple columns per row.
+@udf(
+    input_types=["INT", "INT"],
+    result_type=[("left", "INT"), ("right", "INT"), ("sum", "INT")],
+    batch_mode=True,
+)
+def expand_pairs(left: List[int], right: List[int]):
+    if len(left) != len(right):
+        raise ValueError("Inputs must have the same length")
+    return [
+        {"left": l, "right": r, "sum": l + r} for l, r in zip(left, right)
+    ]
+
 if __name__ == '__main__':
     # create a UDF server listening at '0.0.0.0:8815'
     server = UDFServer("0.0.0.0:8815")
```
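For context, calling the new table-valued function from the client might look like the sketch below. The server address and the `call_function` name/signature are assumptions for illustration (the commit shows `_prepare_function_call` accepting positional args and kwargs, but not the public wrapper itself); what the diff does establish is that a multi-column result decodes to one dict per row.

```python
# Hypothetical usage sketch; UDFClient's constructor argument and the
# call_function name/signature are assumed, not shown in this commit.
from databend_udf.client import UDFClient

client = UDFClient("localhost:8815")

# Two equally sized input columns; the table-valued UDF returns one
# dict per row with the declared output columns.
rows = client.call_function("expand_pairs", [1, 2, 3], [10, 20, 30])
print(rows)
# Expected shape: [{'left': 1, 'right': 10, 'sum': 11}, ...]
```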

python/databend_udf/client.py

Lines changed: 56 additions & 22 deletions
```diff
@@ -7,6 +7,9 @@
 import pyarrow.flight as fl
 
 
+_SCHEMA_METADATA_INPUT_COUNT_KEY = b"x-databend-udf-input-count"
+
+
 class UDFClient:
     """Simple client for calling UDF functions on a Databend UDF server."""
 
```
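The new metadata key records how many leading schema fields are inputs. A minimal sketch in stock pyarrow of why that matters and how a server could attach it: with three output columns, the old "last field is the output" arithmetic would miscount the inputs. The field names here are illustrative.

```python
import pyarrow as pa

_SCHEMA_METADATA_INPUT_COUNT_KEY = b"x-databend-udf-input-count"

# Function schema: input fields first, then the output fields.
# Field names are made up for this sketch.
schema = pa.schema(
    [
        pa.field("a", pa.int32()),      # input
        pa.field("b", pa.int32()),      # input
        pa.field("left", pa.int32()),   # output
        pa.field("right", pa.int32()),  # output
        pa.field("sum", pa.int32()),    # output
    ]
).with_metadata({_SCHEMA_METADATA_INPUT_COUNT_KEY: b"2"})

# Client side: trust the metadata, else fall back to the legacy guess.
raw = (schema.metadata or {}).get(_SCHEMA_METADATA_INPUT_COUNT_KEY)
input_count = int(raw) if raw is not None else max(len(schema) - 1, 0)
print(input_count)  # 2 -- the legacy guess len(schema) - 1 would say 4
```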
```diff
@@ -29,6 +32,45 @@ def _get_cached_schema(self, function_name: str) -> pa.Schema:
         self._schema_cache[function_name] = info.schema
         return self._schema_cache[function_name]
 
+    @staticmethod
+    def _get_input_field_count(schema: pa.Schema) -> int:
+        """Extract the number of input columns from schema metadata."""
+        metadata = schema.metadata or {}
+        key = _SCHEMA_METADATA_INPUT_COUNT_KEY
+        if metadata and key in metadata:
+            try:
+                return int(metadata[key].decode("utf-8"))
+            except (ValueError, AttributeError):
+                pass
+
+        # Fallback for older servers without metadata: assume final column is output.
+        return max(len(schema) - 1, 0)
+
+    @staticmethod
+    def _decode_result_batch(batch: pa.RecordBatch) -> List[Any]:
+        """Convert a result RecordBatch into Python values."""
+        num_columns = batch.num_columns
+        num_rows = batch.num_rows
+
+        if num_columns == 0:
+            return [{} for _ in range(num_rows)]
+
+        if num_columns == 1:
+            column = batch.column(0)
+            return [column[i].as_py() for i in range(num_rows)]
+
+        field_names = [
+            batch.schema.field(i).name or f"col{i}" for i in range(num_columns)
+        ]
+
+        rows: List[Dict[str, Any]] = []
+        for row_idx in range(num_rows):
+            row: Dict[str, Any] = {}
+            for col_idx, name in enumerate(field_names):
+                row[name] = batch.column(col_idx)[row_idx].as_py()
+            rows.append(row)
+        return rows
+
     def _prepare_function_call(
         self, function_name: str, args: tuple = (), kwargs: dict = None
     ) -> tuple:
```
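The decoding rule the helper implements: zero columns yield empty dicts, a single column yields bare scalars (the pre-existing scalar-UDF behavior), and two or more columns yield one name-keyed dict per row. A quick demonstration, assuming the package imports as `databend_udf.client` per the file path:

```python
import pyarrow as pa
from databend_udf.client import UDFClient

# One result column -> plain Python scalars (legacy scalar-UDF shape).
scalar_batch = pa.RecordBatch.from_arrays([pa.array([11, 22])], names=["sum"])
print(UDFClient._decode_result_batch(scalar_batch))  # [11, 22]

# Several result columns -> one dict per row, keyed by column name.
table_batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2]), pa.array([10, 20]), pa.array([11, 22])],
    names=["left", "right", "sum"],
)
print(UDFClient._decode_result_batch(table_batch))
# [{'left': 1, 'right': 10, 'sum': 11}, {'left': 2, 'right': 20, 'sum': 22}]
```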
```diff
@@ -41,32 +83,35 @@ def _prepare_function_call(
         kwargs = kwargs or {}
         schema = self._get_cached_schema(function_name)
 
+        input_field_count = self._get_input_field_count(schema)
+        total_fields = len(schema)
+        if input_field_count > total_fields:
+            raise ValueError(
+                f"Function '{function_name}' schema metadata is invalid (input count {input_field_count} > total fields {total_fields})"
+            )
+
+        input_fields = [schema.field(i) for i in range(input_field_count)]
+        input_schema = pa.schema(input_fields)
+
         # Validate arguments
         if args and kwargs:
             raise ValueError("Cannot mix positional and keyword arguments")
 
         if args:
             # Positional arguments - validate count first
-            total_fields = len(schema)
-            expected_input_count = total_fields - 1  # Last field is always output
-
-            if len(args) != expected_input_count:
+            if len(args) != input_field_count:
                 raise ValueError(
-                    f"Function '{function_name}' expects {expected_input_count} arguments, got {len(args)}"
+                    f"Function '{function_name}' expects {input_field_count} arguments, got {len(args)}"
                 )
 
             if len(args) == 0:
-                input_schema = pa.schema([])
                 # For zero-argument functions, create a batch with 1 row and no columns
                 dummy_array = pa.array([None])
                 temp_batch = pa.RecordBatch.from_arrays(
                     [dummy_array], schema=pa.schema([pa.field("dummy", pa.null())])
                 )
                 batch = temp_batch.select([])
             else:
-                input_fields = [schema.field(i) for i in range(len(args))]
-                input_schema = pa.schema(input_fields)
-
                 arrays = []
                 for i, arg in enumerate(args):
                     if isinstance(arg, list):
```
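With the input schema now derived once up front, the positional-args branch only has to build one Arrow array per argument against that schema. A standalone sketch of that batch construction (the scalar-wrapping detail is an assumption; the diff shows the `isinstance(arg, list)` check but cuts off before the scalar branch):

```python
import pyarrow as pa

# Input schema as the client would derive it from the first N fields.
input_schema = pa.schema([pa.field("a", pa.int32()), pa.field("b", pa.int32())])

args = ([1, 2, 3], [10, 20, 30])
arrays = [
    # Lists are taken as whole columns; a bare scalar is assumed to be
    # wrapped into a one-element column.
    pa.array(arg if isinstance(arg, list) else [arg], type=field.type)
    for arg, field in zip(args, input_schema)
]
batch = pa.RecordBatch.from_arrays(arrays, schema=input_schema)
print(batch.num_rows, batch.num_columns)  # 3 2
```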
```diff
@@ -85,13 +130,6 @@ def _prepare_function_call(
                 )
                 batch = temp_batch.select([])
         else:
-            # Extract only input fields (exclude the last field which is output)
-            # The schema contains input fields + 1 output field
-            total_fields = len(schema)
-            num_input_fields = total_fields - 1  # Last field is always output
-            input_fields = [schema.field(i) for i in range(num_input_fields)]
-            input_schema = pa.schema(input_fields)
-
             # Validate kwargs
             expected_fields = {field.name for field in input_schema}
             provided_fields = set(kwargs.keys())
```
```diff
@@ -223,12 +261,10 @@
         writer.write_batch(batch)
         writer.done_writing()
 
-        # Get results
         results = []
         for result_chunk in reader:
             result_batch = result_chunk.data
-            for i in range(result_batch.num_rows):
-                results.append(result_batch.column(0)[i].as_py())
+            results.extend(self._decode_result_batch(result_batch))
 
         return results
 
@@ -264,12 +300,10 @@
         writer.write_batch(batch)
         writer.done_writing()
 
-        # Get results
         results = []
         for result_chunk in reader:
             result_batch = result_chunk.data
-            for i in range(result_batch.num_rows):
-                results.append(result_batch.column(0)[i].as_py())
+            results.extend(self._decode_result_batch(result_batch))
 
         return results
```
