Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: Example of calling Python UDF & UDAF in SQL #258

Merged
merged 3 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ See [examples](examples/README.md) for more information.
- [Query a Parquet file using SQL](./examples/sql-parquet.py)
- [Query a Parquet file using the DataFrame API](./examples/dataframe-parquet.py)
- [Run a SQL query and store the results in a Pandas DataFrame](./examples/sql-to-pandas.py)
- [Run a SQL query with a Python user-defined function (UDF)](./examples/sql-using-python-udf.py)
- [Run a SQL query with a Python user-defined aggregation function (UDAF)](./examples/sql-using-python-udaf.py)
- [Query PyArrow Data](./examples/query-pyarrow-data.py)
- [Create dataframe](./examples/import.py)
- [Export dataframe](./examples/export.py)

### Running User-Defined Python Code

Expand Down
4 changes: 2 additions & 2 deletions datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def udf(func, input_types, return_type, volatility, name=None):
if not callable(func):
raise TypeError("`func` argument must be callable")
if name is None:
name = func.__qualname__
name = func.__qualname__.lower()
return ScalarUDF(
name=name,
func=func,
Expand All @@ -190,7 +190,7 @@ def udaf(accum, input_type, return_type, state_type, volatility, name=None):
"`accum` must implement the abstract base class Accumulator"
)
if name is None:
name = accum.__qualname__
name = accum.__qualname__.lower()
Copy link
Contributor Author

@simicd simicd Mar 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously calling the UDAF in SQL would fail if name is not specified (i.e. both MyAccumulator(b) and myaccumulator(b) would fail when running e.g. result_df = ctx.sql(f"select a, myaccumulator(b) as b_aggregated from {table_name}") ). With this change both the uppercase & lowercase variant will work.

return AggregateUDF(
name=name,
accumulator=accum,
Expand Down
91 changes: 91 additions & 0 deletions examples/sql-using-python-udaf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datafusion import udaf, SessionContext, Accumulator
import pyarrow as pa


# Define a user-defined aggregation function (UDAF)
class MyAccumulator(Accumulator):
"""
Interface of a user-defined accumulation.
"""

def __init__(self):
self._sum = pa.scalar(0.0)

def update(self, values: pa.Array) -> None:
# not nice since pyarrow scalars can't be summed yet. This breaks on `None`
self._sum = pa.scalar(
self._sum.as_py() + pa.compute.sum(values).as_py()
)

def merge(self, states: pa.Array) -> None:
# not nice since pyarrow scalars can't be summed yet. This breaks on `None`
self._sum = pa.scalar(
self._sum.as_py() + pa.compute.sum(states).as_py()
)

def state(self) -> pa.Array:
return pa.array([self._sum.as_py()])

def evaluate(self) -> pa.Scalar:
return self._sum


my_udaf = udaf(
MyAccumulator,
pa.float64(),
pa.float64(),
[pa.float64()],
"stable",
# This will be the name of the UDAF in SQL
# If not specified it will by default the same as accumulator class name
# name="my_accumulator",
)

# Create a context
ctx = SessionContext()

# Create a datafusion DataFrame from a Python dictionary
source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]})
# Dataframe:
# +---+---+
# | a | b |
# +---+---+
# | 1 | 4 |
# | 1 | 5 |
# | 3 | 6 |
# +---+---+

# Register UDF for use in SQL
ctx.register_udaf(my_udaf)

# Query the DataFrame using SQL
table_name = ctx.catalog().database().names().pop()
result_df = ctx.sql(
f"select a, my_accumulator(b) as b_aggregated from {table_name} group by a order by a"
)
Comment on lines +78 to +82
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it would be nice if the table name can be set in the from... context functions so table_name is known, have created ticket #259 as follow-up task

# Dataframe:
# +---+--------------+
# | a | b_aggregated |
# +---+--------------+
# | 1 | 9 |
# | 3 | 6 |
# +---+--------------+
assert result_df.to_pydict()["a"] == [1, 3]
assert result_df.to_pydict()["b_aggregated"] == [9, 6]
65 changes: 65 additions & 0 deletions examples/sql-using-python-udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datafusion import udf, SessionContext
import pyarrow as pa


# Define a user-defined function (UDF)
def is_null(array: pa.Array) -> pa.Array:
return array.is_null()


is_null_arr = udf(
is_null,
[pa.int64()],
pa.bool_(),
"stable",
# This will be the name of the UDF in SQL
# If not specified it will by default the same as Python function name
name="is_null",
)

# Create a context
ctx = SessionContext()

# Create a datafusion DataFrame from a Python dictionary
source_df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]})
# Dataframe:
# +---+---+
# | a | b |
# +---+---+
# | 1 | 4 |
# | 2 | |
# | 3 | 6 |
# +---+---+

# Register UDF for use in SQL
ctx.register_udf(is_null_arr)

# Query the DataFrame using SQL
table_name = ctx.catalog().database().names().pop()
result_df = ctx.sql(f"select a, is_null(b) as b_is_null from {table_name}")
# Dataframe:
# +---+-----------+
# | a | b_is_null |
# +---+-----------+
# | 1 | false |
# | 2 | true |
# | 3 | false |
# +---+-----------+
assert result_df.to_pydict()["b_is_null"] == [False, True, False]