Skip to content

Commit 150e97f

Browse files
committed
Add table-function helper tests
1 parent 7ada5a2 commit 150e97f

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pyarrow as pa
2+
3+
from databend_udf import Headers, udf
4+
5+
6+
def _make_input_batch(func, columns):
    """Build the input ``RecordBatch`` for *func* from plain column values.

    Each entry of *columns* is paired positionally with a field of
    ``func._input_schema`` and converted to an Arrow array of that field's
    type; the resulting arrays are assembled under the same schema.
    """
    schema = func._input_schema
    arrays = []
    # zip() pairs column values with schema fields by position.
    for field, values in zip(schema, columns):
        arrays.append(pa.array(values, type=field.type))
    return pa.RecordBatch.from_arrays(arrays, schema=schema)
12+
13+
14+
def _batch_to_rows(batch):
    """Convert a column-oriented batch into a list of per-row dicts.

    *batch* must provide ``to_pydict()`` (a mapping of column name to a
    sequence of values) and ``num_rows``. One dict is produced per row
    index in ``range(batch.num_rows)``, keyed by column name.
    """
    columns = batch.to_pydict()
    return [
        {name: values[row_idx] for name, values in columns.items()}
        for row_idx in range(batch.num_rows)
    ]
21+
22+
23+
def _collect_rows(func, batch):
    """Evaluate *func* on *batch* and return all output rows as dicts.

    ``func.eval_batch`` is called with *batch* and a fresh ``Headers()``;
    every batch it yields is flattened through ``_batch_to_rows`` and the
    rows are concatenated in the order they are produced.
    """
    return [
        row
        for produced in func.eval_batch(batch, Headers())
        for row in _batch_to_rows(produced)
    ]
28+
29+
30+
@udf(
    input_types=["INT"],
    result_type=[("value", "BIGINT"), ("square", "BIGINT")],
    batch_mode=True,
)
def table_returns_record_batch(nums):
    """Fixture UDF returning a ``pyarrow.RecordBatch`` directly.

    For the values in *nums* it produces two int64 columns: ``value`` (the
    inputs) and ``square`` (each input multiplied by itself).
    """
    # NOTE(review): assumes nums iterates as plain Python ints — confirm
    # against how databend_udf passes batch-mode arguments.
    values = pa.array(nums, type=pa.int64())
    squares = pa.array([n * n for n in nums], type=pa.int64())
    out_schema = pa.schema(
        [
            pa.field("value", pa.int64()),
            pa.field("square", pa.int64()),
        ]
    )
    return pa.RecordBatch.from_arrays([values, squares], schema=out_schema)
44+
45+
46+
@udf(
    input_types=["INT", "INT"],
    result_type=[("left", "INT"), ("right", "INT"), ("sum", "INT")],
    batch_mode=True,
)
def table_returns_iterable(lhs, rhs):
    """Fixture UDF returning a list of row dicts instead of a batch.

    *lhs* and *rhs* are walked in lockstep; each pair yields one dict with
    the keys ``left``, ``right`` and ``sum``.
    """
    return [
        {"left": a, "right": b, "sum": a + b}
        for a, b in zip(lhs, rhs)
    ]
56+
57+
58+
def test_table_function_accepts_record_batch():
    """A UDF that returns a ``RecordBatch`` yields the expected rows."""
    expected = [
        {"value": 1, "square": 1},
        {"value": 2, "square": 4},
        {"value": 3, "square": 9},
    ]
    input_batch = _make_input_batch(table_returns_record_batch, [[1, 2, 3]])
    actual = _collect_rows(table_returns_record_batch, input_batch)
    assert actual == expected
66+
67+
68+
def test_table_function_accepts_iterable_of_dicts():
    """A UDF that returns a list of dicts yields the expected rows."""
    expected = [
        {"left": 1, "right": 10, "sum": 11},
        {"left": 2, "right": 20, "sum": 22},
    ]
    input_batch = _make_input_batch(table_returns_iterable, [[1, 2], [10, 20]])
    actual = _collect_rows(table_returns_iterable, input_batch)
    assert actual == expected

0 commit comments

Comments
 (0)