From a3c108f5a08e485627dc8550c4e0b1d92de90a75 Mon Sep 17 00:00:00 2001 From: Dejan Simic <10134699+simicd@users.noreply.github.com> Date: Mon, 6 Mar 2023 15:43:40 +0100 Subject: [PATCH] docs: Example of calling Python UDF & UDAF in SQL (#258) * Document UDF calls in SQL * Remove unnecessary imports * FIx example --- README.md | 4 ++ datafusion/__init__.py | 4 +- examples/sql-using-python-udaf.py | 91 +++++++++++++++++++++++++++++++ examples/sql-using-python-udf.py | 65 ++++++++++++++++++++++ 4 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 examples/sql-using-python-udaf.py create mode 100644 examples/sql-using-python-udf.py diff --git a/README.md b/README.md index 923b6be0..7c29defd 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,11 @@ See [examples](examples/README.md) for more information. - [Query a Parquet file using SQL](./examples/sql-parquet.py) - [Query a Parquet file using the DataFrame API](./examples/dataframe-parquet.py) - [Run a SQL query and store the results in a Pandas DataFrame](./examples/sql-to-pandas.py) +- [Run a SQL query with a Python user-defined function (UDF)](./examples/sql-using-python-udf.py) +- [Run a SQL query with a Python user-defined aggregation function (UDAF)](./examples/sql-using-python-udaf.py) - [Query PyArrow Data](./examples/query-pyarrow-data.py) +- [Create dataframe](./examples/import.py) +- [Export dataframe](./examples/export.py) ### Running User-Defined Python Code diff --git a/datafusion/__init__.py b/datafusion/__init__.py index f5583c29..a7878e1b 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -171,7 +171,7 @@ def udf(func, input_types, return_type, volatility, name=None): if not callable(func): raise TypeError("`func` argument must be callable") if name is None: - name = func.__qualname__ + name = func.__qualname__.lower() return ScalarUDF( name=name, func=func, @@ -190,7 +190,7 @@ def udaf(accum, input_type, return_type, state_type, volatility, name=None): "`accum` must implement the abstract base class Accumulator" ) if name is None: - name = accum.__qualname__ + name = accum.__qualname__.lower() return AggregateUDF( name=name, accumulator=accum, diff --git a/examples/sql-using-python-udaf.py b/examples/sql-using-python-udaf.py new file mode 100644 index 00000000..9aacc5d4 --- /dev/null +++ b/examples/sql-using-python-udaf.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datafusion import udaf, SessionContext, Accumulator +import pyarrow as pa + + +# Define a user-defined aggregation function (UDAF) +class MyAccumulator(Accumulator): + """ + Interface of a user-defined accumulation. + """ + + def __init__(self): + self._sum = pa.scalar(0.0) + + def update(self, values: pa.Array) -> None: + # not nice since pyarrow scalars can't be summed yet. This breaks on `None` + self._sum = pa.scalar( + self._sum.as_py() + pa.compute.sum(values).as_py() + ) + + def merge(self, states: pa.Array) -> None: + # not nice since pyarrow scalars can't be summed yet. This breaks on `None` + self._sum = pa.scalar( + self._sum.as_py() + pa.compute.sum(states).as_py() + ) + + def state(self) -> pa.Array: + return pa.array([self._sum.as_py()]) + + def evaluate(self) -> pa.Scalar: + return self._sum + + +my_udaf = udaf( + MyAccumulator, + pa.float64(), + pa.float64(), + [pa.float64()], + "stable", + # This will be the name of the UDAF in SQL + # If not specified it will by default the same as accumulator class name + name="my_accumulator", +) + +# Create a context +ctx = SessionContext() + +# Create a datafusion DataFrame from a Python dictionary +source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]}) +# Dataframe: +# +---+---+ +# | a | b | +# +---+---+ +# | 1 | 4 | +# | 1 | 5 | +# | 3 | 6 | +# +---+---+ + +# Register UDF for use in SQL +ctx.register_udaf(my_udaf) + +# Query the DataFrame using SQL +table_name = ctx.catalog().database().names().pop() +result_df = ctx.sql( + f"select a, my_accumulator(b) as b_aggregated from {table_name} group by a order by a" +) +# Dataframe: +# +---+--------------+ +# | a | b_aggregated | +# +---+--------------+ +# | 1 | 9 | +# | 3 | 6 | +# +---+--------------+ +assert result_df.to_pydict()["a"] == [1, 3] +assert result_df.to_pydict()["b_aggregated"] == [9, 6] diff --git a/examples/sql-using-python-udf.py b/examples/sql-using-python-udf.py new file mode 100644 index 00000000..717b88e2 --- /dev/null +++ b/examples/sql-using-python-udf.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datafusion import udf, SessionContext +import pyarrow as pa + + +# Define a user-defined function (UDF) +def is_null(array: pa.Array) -> pa.Array: + return array.is_null() + + +is_null_arr = udf( + is_null, + [pa.int64()], + pa.bool_(), + "stable", + # This will be the name of the UDF in SQL + # If not specified it will by default the same as Python function name + name="is_null", +) + +# Create a context +ctx = SessionContext() + +# Create a datafusion DataFrame from a Python dictionary +source_df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]}) +# Dataframe: +# +---+---+ +# | a | b | +# +---+---+ +# | 1 | 4 | +# | 2 | | +# | 3 | 6 | +# +---+---+ + +# Register UDF for use in SQL +ctx.register_udf(is_null_arr) + +# Query the DataFrame using SQL +table_name = ctx.catalog().database().names().pop() +result_df = ctx.sql(f"select a, is_null(b) as b_is_null from {table_name}") +# Dataframe: +# +---+-----------+ +# | a | b_is_null | +# +---+-----------+ +# | 1 | false | +# | 2 | true | +# | 3 | false | +# +---+-----------+ +assert result_df.to_pydict()["b_is_null"] == [False, True, False]