Skip to content

Commit e139844

Browse files
committed
Format files
1 parent 510abf9 commit e139844

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

python/databend_udf/udf.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ class CallableFunction(UserDefinedFunction):
388388
This handles parsing function parameters, stage bindings, and schema
389389
construction so both scalar and table-valued functions can reuse the logic.
390390
"""
391+
391392
_func: Callable
392393
_stage_ref_names: List[str]
393394
_headers_param: Optional[str]
@@ -504,7 +505,9 @@ def __init__(
504505
def _ensure_headers(self, headers: Optional[Headers]) -> Headers:
505506
return headers if isinstance(headers, Headers) else Headers(headers)
506507

507-
def _resolve_stage_locations(self, headers_obj: Headers) -> Dict[str, StageLocation]:
508+
def _resolve_stage_locations(
509+
self, headers_obj: Headers
510+
) -> Dict[str, StageLocation]:
508511
stage_locations: Dict[str, StageLocation] = {}
509512
if self._stage_ref_names:
510513
stage_locations_by_ref = headers_obj.require_stage_locations(
@@ -839,8 +842,7 @@ def _is_column_mapping(self, mapping: Mapping[str, Any]) -> bool:
839842
if not mapping:
840843
return True
841844
return all(
842-
isinstance(values, Sequence)
843-
and not isinstance(values, (bytes, str))
845+
isinstance(values, Sequence) and not isinstance(values, (bytes, str))
844846
for values in mapping.values()
845847
)
846848

@@ -879,7 +881,9 @@ def _normalize_result_type(
879881
field = _to_arrow_field(item).with_name(f"col{idx}")
880882
fields.append(field)
881883
if not fields:
882-
raise ValueError("Table function result_type must contain at least one column")
884+
raise ValueError(
885+
"Table function result_type must contain at least one column"
886+
)
883887
_ensure_unique_names([field.name for field in fields])
884888
return fields, True
885889

@@ -918,6 +922,7 @@ def udf(
918922
)
919923

920924
if is_table:
925+
921926
def decorator(f):
922927
return TableFunction(
923928
f,
@@ -933,6 +938,7 @@ def decorator(f):
933938
return decorator
934939

935940
if io_threads is not None and io_threads > 1:
941+
936942
def decorator(f):
937943
return ScalarFunction(
938944
f,
@@ -960,6 +966,7 @@ def decorator(f):
960966

961967
return decorator
962968

969+
963970
class UDFServer(FlightServerBase):
964971
"""
965972
A server that provides user-defined functions to clients.
@@ -1062,7 +1069,9 @@ def get_flight_info(self, context, descriptor):
10621069
# return the concatenation of input and output schema
10631070
full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema))
10641071
metadata = dict(full_schema.metadata.items()) if full_schema.metadata else {}
1065-
metadata[_SCHEMA_METADATA_INPUT_COUNT_KEY] = str(len(udf._input_schema)).encode("utf-8")
1072+
metadata[_SCHEMA_METADATA_INPUT_COUNT_KEY] = str(len(udf._input_schema)).encode(
1073+
"utf-8"
1074+
)
10661075
full_schema = full_schema.with_metadata(metadata)
10671076
return FlightInfo(
10681077
schema=full_schema,

python/tests/test_simple_udf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,7 @@ def test_table_function_returns_records(basic_server):
9292
"""Test table-valued function returning multiple columns."""
9393
client = basic_server.get_client()
9494

95-
batch_result = client.call_function_batch(
96-
"expand_numbers", nums=[1, 2, 3]
97-
)
95+
batch_result = client.call_function_batch("expand_numbers", nums=[1, 2, 3])
9896
assert batch_result == [
9997
{"num": 1, "double_num": 2},
10098
{"num": 2, "double_num": 4},

0 commit comments

Comments
 (0)