|
| 1 | +import pyarrow as pa |
| 2 | + |
| 3 | +from databend_udf import Headers, udf |
| 4 | + |
| 5 | + |
| 6 | +def _make_input_batch(func, columns): |
| 7 | + arrays = [ |
| 8 | + pa.array(values, type=field.type) |
| 9 | + for values, field in zip(columns, func._input_schema) |
| 10 | + ] |
| 11 | + return pa.RecordBatch.from_arrays(arrays, schema=func._input_schema) |
| 12 | + |
| 13 | + |
| 14 | +def _batch_to_rows(batch): |
| 15 | + data = batch.to_pydict() |
| 16 | + keys = list(data.keys()) |
| 17 | + rows = [] |
| 18 | + for idx in range(batch.num_rows): |
| 19 | + rows.append({key: data[key][idx] for key in keys}) |
| 20 | + return rows |
| 21 | + |
| 22 | + |
| 23 | +def _collect_rows(func, batch): |
| 24 | + rows = [] |
| 25 | + for out_batch in func.eval_batch(batch, Headers()): |
| 26 | + rows.extend(_batch_to_rows(out_batch)) |
| 27 | + return rows |
| 28 | + |
| 29 | + |
| 30 | +@udf( |
| 31 | + input_types=["INT"], |
| 32 | + result_type=[("value", "BIGINT"), ("square", "BIGINT")], |
| 33 | + batch_mode=True, |
| 34 | +) |
| 35 | +def table_returns_record_batch(nums): |
| 36 | + schema = pa.schema([pa.field("value", pa.int64()), pa.field("square", pa.int64())]) |
| 37 | + return pa.RecordBatch.from_arrays( |
| 38 | + [ |
| 39 | + pa.array(nums, type=pa.int64()), |
| 40 | + pa.array([n * n for n in nums], type=pa.int64()), |
| 41 | + ], |
| 42 | + schema=schema, |
| 43 | + ) |
| 44 | + |
| 45 | + |
| 46 | +@udf( |
| 47 | + input_types=["INT", "INT"], |
| 48 | + result_type=[("left", "INT"), ("right", "INT"), ("sum", "INT")], |
| 49 | + batch_mode=True, |
| 50 | +) |
| 51 | +def table_returns_iterable(lhs, rhs): |
| 52 | + return [ |
| 53 | + {"left": left_value, "right": right_value, "sum": left_value + right_value} |
| 54 | + for left_value, right_value in zip(lhs, rhs) |
| 55 | + ] |
| 56 | + |
| 57 | + |
| 58 | +def test_table_function_accepts_record_batch(): |
| 59 | + batch = _make_input_batch(table_returns_record_batch, [[1, 2, 3]]) |
| 60 | + rows = _collect_rows(table_returns_record_batch, batch) |
| 61 | + assert rows == [ |
| 62 | + {"value": 1, "square": 1}, |
| 63 | + {"value": 2, "square": 4}, |
| 64 | + {"value": 3, "square": 9}, |
| 65 | + ] |
| 66 | + |
| 67 | + |
| 68 | +def test_table_function_accepts_iterable_of_dicts(): |
| 69 | + batch = _make_input_batch(table_returns_iterable, [[1, 2], [10, 20]]) |
| 70 | + rows = _collect_rows(table_returns_iterable, batch) |
| 71 | + assert rows == [ |
| 72 | + {"left": 1, "right": 10, "sum": 11}, |
| 73 | + {"left": 2, "right": 20, "sum": 22}, |
| 74 | + ] |
0 commit comments