77import pyarrow .flight as fl
88
99
10+ _SCHEMA_METADATA_INPUT_COUNT_KEY = b"x-databend-udf-input-count"
11+
12+
1013class UDFClient :
1114 """Simple client for calling UDF functions on a Databend UDF server."""
1215
@@ -29,6 +32,45 @@ def _get_cached_schema(self, function_name: str) -> pa.Schema:
2932 self ._schema_cache [function_name ] = info .schema
3033 return self ._schema_cache [function_name ]
3134
35+ @staticmethod
36+ def _get_input_field_count (schema : pa .Schema ) -> int :
37+ """Extract the number of input columns from schema metadata."""
38+ metadata = schema .metadata or {}
39+ key = _SCHEMA_METADATA_INPUT_COUNT_KEY
40+ if metadata and key in metadata :
41+ try :
42+ return int (metadata [key ].decode ("utf-8" ))
43+ except (ValueError , AttributeError ):
44+ pass
45+
46+ # Fallback for older servers without metadata: assume final column is output.
47+ return max (len (schema ) - 1 , 0 )
48+
49+ @staticmethod
50+ def _decode_result_batch (batch : pa .RecordBatch ) -> List [Any ]:
51+ """Convert a result RecordBatch into Python values."""
52+ num_columns = batch .num_columns
53+ num_rows = batch .num_rows
54+
55+ if num_columns == 0 :
56+ return [{} for _ in range (num_rows )]
57+
58+ if num_columns == 1 :
59+ column = batch .column (0 )
60+ return [column [i ].as_py () for i in range (num_rows )]
61+
62+ field_names = [
63+ batch .schema .field (i ).name or f"col{ i } " for i in range (num_columns )
64+ ]
65+
66+ rows : List [Dict [str , Any ]] = []
67+ for row_idx in range (num_rows ):
68+ row : Dict [str , Any ] = {}
69+ for col_idx , name in enumerate (field_names ):
70+ row [name ] = batch .column (col_idx )[row_idx ].as_py ()
71+ rows .append (row )
72+ return rows
73+
3274 def _prepare_function_call (
3375 self , function_name : str , args : tuple = (), kwargs : dict = None
3476 ) -> tuple :
@@ -41,32 +83,35 @@ def _prepare_function_call(
4183 kwargs = kwargs or {}
4284 schema = self ._get_cached_schema (function_name )
4385
86+ input_field_count = self ._get_input_field_count (schema )
87+ total_fields = len (schema )
88+ if input_field_count > total_fields :
89+ raise ValueError (
90+ f"Function '{ function_name } ' schema metadata is invalid (input count { input_field_count } > total fields { total_fields } )"
91+ )
92+
93+ input_fields = [schema .field (i ) for i in range (input_field_count )]
94+ input_schema = pa .schema (input_fields )
95+
4496 # Validate arguments
4597 if args and kwargs :
4698 raise ValueError ("Cannot mix positional and keyword arguments" )
4799
48100 if args :
49101 # Positional arguments - validate count first
50- total_fields = len (schema )
51- expected_input_count = total_fields - 1 # Last field is always output
52-
53- if len (args ) != expected_input_count :
102+ if len (args ) != input_field_count :
54103 raise ValueError (
55- f"Function '{ function_name } ' expects { expected_input_count } arguments, got { len (args )} "
104+ f"Function '{ function_name } ' expects { input_field_count } arguments, got { len (args )} "
56105 )
57106
58107 if len (args ) == 0 :
59- input_schema = pa .schema ([])
60108 # For zero-argument functions, create a batch with 1 row and no columns
61109 dummy_array = pa .array ([None ])
62110 temp_batch = pa .RecordBatch .from_arrays (
63111 [dummy_array ], schema = pa .schema ([pa .field ("dummy" , pa .null ())])
64112 )
65113 batch = temp_batch .select ([])
66114 else :
67- input_fields = [schema .field (i ) for i in range (len (args ))]
68- input_schema = pa .schema (input_fields )
69-
70115 arrays = []
71116 for i , arg in enumerate (args ):
72117 if isinstance (arg , list ):
@@ -85,13 +130,6 @@ def _prepare_function_call(
85130 )
86131 batch = temp_batch .select ([])
87132 else :
88- # Extract only input fields (exclude the last field which is output)
89- # The schema contains input fields + 1 output field
90- total_fields = len (schema )
91- num_input_fields = total_fields - 1 # Last field is always output
92- input_fields = [schema .field (i ) for i in range (num_input_fields )]
93- input_schema = pa .schema (input_fields )
94-
95133 # Validate kwargs
96134 expected_fields = {field .name for field in input_schema }
97135 provided_fields = set (kwargs .keys ())
@@ -223,12 +261,10 @@ def call_function(
223261 writer .write_batch (batch )
224262 writer .done_writing ()
225263
226- # Get results
227264 results = []
228265 for result_chunk in reader :
229266 result_batch = result_chunk .data
230- for i in range (result_batch .num_rows ):
231- results .append (result_batch .column (0 )[i ].as_py ())
267+ results .extend (self ._decode_result_batch (result_batch ))
232268
233269 return results
234270
@@ -264,12 +300,10 @@ def call_function_batch(
264300 writer .write_batch (batch )
265301 writer .done_writing ()
266302
267- # Get results
268303 results = []
269304 for result_chunk in reader :
270305 result_batch = result_chunk .data
271- for i in range (result_batch .num_rows ):
272- results .append (result_batch .column (0 )[i ].as_py ())
306+ results .extend (self ._decode_result_batch (result_batch ))
273307
274308 return results
275309
0 commit comments