Skip to content

Commit 7ada5a2

Browse files
committed
Improve table UDF example
1 parent e139844 commit 7ada5a2

File tree

3 files changed

+27
-16
lines changed

3 files changed

+27
-16
lines changed

python/README.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,18 @@ def wait_concurrent(x):
4242
time.sleep(2)
4343
return x
4444

45-
# Define a table-valued function that expands inputs into multiple rows.
45+
# Define a table-valued function that emits multiple columns per row.
4646
@udf(
47-
input_types=["INT"],
48-
result_type=[("num", "INT"), ("double_num", "INT")],
47+
input_types=["INT", "INT"],
48+
result_type=[("left", "INT"), ("right", "INT"), ("sum", "INT")],
4949
batch_mode=True,
5050
)
51-
def expand_numbers(nums: List[int]):
52-
return [{"num": value, "double_num": value * 2} for value in nums]
51+
def expand_pairs(left: List[int], right: List[int]):
52+
if len(left) != len(right):
53+
raise ValueError("Inputs must have the same length")
54+
return [
55+
{"left": l, "right": r, "sum": l + r} for l, r in zip(left, right)
56+
]
5357

5458
if __name__ == '__main__':
5559
# create a UDF server listening at '0.0.0.0:8815'

python/tests/servers/basic_server.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,25 @@ def gcd(x, y):
2222

2323

2424
@udf(
25-
input_types=["INT"],
26-
result_type=[("num", "INT"), ("double_num", "INT")],
25+
input_types=["INT", "INT"],
26+
result_type=[("left", "INT"), ("right", "INT"), ("sum", "INT")],
2727
batch_mode=True,
2828
)
29-
def expand_numbers(nums):
30-
return [{"num": value, "double_num": value * 2} for value in nums]
29+
def expand_pairs(lhs, rhs):
30+
if len(lhs) != len(rhs):
31+
raise ValueError("lhs and rhs must have the same length")
32+
return [
33+
{"left": left, "right": right, "sum": left + right}
34+
for left, right in zip(lhs, rhs)
35+
]
3136

3237

3338
def create_basic_server(port=8815):
3439
"""Create server with basic arithmetic functions."""
3540
server = UDFServer(f"0.0.0.0:{port}")
3641
server.add_function(add_two_ints)
3742
server.add_function(gcd)
38-
server.add_function(expand_numbers)
43+
server.add_function(expand_pairs)
3944
return server
4045

4146

python/tests/test_simple_udf.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,14 @@ def test_table_function_returns_records(basic_server):
9292
"""Test table-valued function returning multiple columns."""
9393
client = basic_server.get_client()
9494

95-
batch_result = client.call_function_batch("expand_numbers", nums=[1, 2, 3])
95+
batch_result = client.call_function_batch(
96+
"expand_pairs", lhs=[1, 2, 3], rhs=[10, 20, 30]
97+
)
9698
assert batch_result == [
97-
{"num": 1, "double_num": 2},
98-
{"num": 2, "double_num": 4},
99-
{"num": 3, "double_num": 6},
99+
{"left": 1, "right": 10, "sum": 11},
100+
{"left": 2, "right": 20, "sum": 22},
101+
{"left": 3, "right": 30, "sum": 33},
100102
]
101103

102-
single_result = client.call_function("expand_numbers", 5)
103-
assert single_result == [{"num": 5, "double_num": 10}]
104+
single_result = client.call_function("expand_pairs", 4, 6)
105+
assert single_result == [{"left": 4, "right": 6, "sum": 10}]

0 commit comments

Comments
 (0)