1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: { ArrayIter , ArrayRef , AsArray , Int64Array , RecordBatch , StringArray } ;
19- use arrow:: compute:: kernels:: cmp:: eq;
18+ //! This example shows how to create and use "Async UDFs" in DataFusion.
19+ //!
20+ //! Async UDFs allow you to perform asynchronous operations, such as
21+ //! making network requests. This can be used for tasks like fetching
22+ //! data from an external API such as a LLM service or an external database.
23+
24+ use arrow:: array:: { ArrayRef , BooleanArray , Int64Array , RecordBatch , StringArray } ;
2025use arrow_schema:: { DataType , Field , Schema } ;
2126use async_trait:: async_trait;
27+ use datafusion:: assert_batches_eq;
28+ use datafusion:: common:: cast:: as_string_view_array;
2229use datafusion:: common:: error:: Result ;
23- use datafusion:: common:: types :: { logical_int64 , logical_string } ;
30+ use datafusion:: common:: not_impl_err ;
2431use datafusion:: common:: utils:: take_function_args;
25- use datafusion:: common:: { internal_err, not_impl_err} ;
2632use datafusion:: config:: ConfigOptions ;
33+ use datafusion:: execution:: SessionStateBuilder ;
2734use datafusion:: logical_expr:: async_udf:: { AsyncScalarUDF , AsyncScalarUDFImpl } ;
2835use datafusion:: logical_expr:: {
29- ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
30- TypeSignatureClass , Volatility ,
36+ ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
3137} ;
32- use datafusion:: logical_expr_common:: signature:: Coercion ;
33- use datafusion:: physical_expr_common:: datum:: apply_cmp;
34- use datafusion:: prelude:: SessionContext ;
35- use log:: trace;
38+ use datafusion:: prelude:: { SessionConfig , SessionContext } ;
3639use std:: any:: Any ;
3740use std:: sync:: Arc ;
3841
3942#[ tokio:: main]
4043async fn main ( ) -> Result < ( ) > {
41- let ctx: SessionContext = SessionContext :: new ( ) ;
42-
43- let async_upper = AsyncUpper :: new ( ) ;
44- let udf = AsyncScalarUDF :: new ( Arc :: new ( async_upper) ) ;
45- ctx. register_udf ( udf. into_scalar_udf ( ) ) ;
46- let async_equal = AsyncEqual :: new ( ) ;
44+ // Use a hard coded parallelism level of 4 so the explain plan
45+ // is consistent across machines.
46+ let config = SessionConfig :: new ( ) . with_target_partitions ( 4 ) ;
47+ let ctx =
48+ SessionContext :: from ( SessionStateBuilder :: new ( ) . with_config ( config) . build ( ) ) ;
49+
50+ // Similarly to regular UDFs, you create an AsyncScalarUDF by implementing
51+ // `AsyncScalarUDFImpl` and creating an instance of `AsyncScalarUDF`.
52+ let async_equal = AskLLM :: new ( ) ;
4753 let udf = AsyncScalarUDF :: new ( Arc :: new ( async_equal) ) ;
54+
55+ // Async UDFs are registered with the SessionContext, using the same
56+ // `register_udf` method as regular UDFs.
4857 ctx. register_udf ( udf. into_scalar_udf ( ) ) ;
58+
59+ // Create a table named 'animal' with some sample data
4960 ctx. register_batch ( "animal" , animal ( ) ?) ?;
5061
51- // use Async UDF in the projection
52- // +---------------+----------------------------------------------------------------------------------------+
53- // | plan_type | plan |
54- // +---------------+----------------------------------------------------------------------------------------+
55- // | logical_plan | Projection: async_equal(a.id, Int64(1)) |
56- // | | SubqueryAlias: a |
57- // | | TableScan: animal projection=[id] |
58- // | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as async_equal(a.id,Int64(1))] |
59- // | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
60- // | | CoalesceBatchesExec: target_batch_size=8192 |
61- // | | DataSourceExec: partitions=1, partition_sizes=[1] |
62- // | | |
63- // +---------------+----------------------------------------------------------------------------------------+
64- ctx. sql ( "explain select async_equal(a.id, 1) from animal a" )
62+ // You can use the async UDF as normal in SQL queries
63+ //
64+ // Note: Async UDFs can currently be used in the select list and filter conditions.
65+ let results = ctx
66+ . sql ( "select * from animal a where ask_llm(a.name, 'Is this animal furry?')" )
6567 . await ?
66- . show ( )
68+ . collect ( )
6769 . await ?;
6870
69- // +----------------------------+
70- // | async_equal(a.id,Int64(1)) |
71- // +----------------------------+
72- // | true |
73- // | false |
74- // | false |
75- // | false |
76- // | false |
77- // +----------------------------+
78- ctx. sql ( "select async_equal(a.id, 1) from animal a" )
71+ assert_batches_eq ! (
72+ [
73+ "+----+------+" ,
74+ "| id | name |" ,
75+ "+----+------+" ,
76+ "| 1 | cat |" ,
77+ "| 2 | dog |" ,
78+ "+----+------+" ,
79+ ] ,
80+ & results
81+ ) ;
82+
83+ // While the interface is the same for both normal and async UDFs, you can
84+ // use `EXPLAIN` output to see that the async UDF uses a special
85+ // `AsyncFuncExec` node in the physical plan:
86+ let results = ctx
87+ . sql ( "explain select * from animal a where ask_llm(a.name, 'Is this animal furry?')" )
7988 . await ?
80- . show ( )
89+ . collect ( )
8190 . await ?;
8291
83- // use Async UDF in the filter
84- // +---------------+--------------------------------------------------------------------------------------------+
85- // | plan_type | plan |
86- // +---------------+--------------------------------------------------------------------------------------------+
87- // | logical_plan | SubqueryAlias: a |
88- // | | Filter: async_equal(animal.id, Int64(1)) |
89- // | | TableScan: animal projection=[id, name] |
90- // | physical_plan | CoalesceBatchesExec: target_batch_size=8192 |
91- // | | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |
92- // | | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |
93- // | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
94- // | | CoalesceBatchesExec: target_batch_size=8192 |
95- // | | DataSourceExec: partitions=1, partition_sizes=[1] |
96- // | | |
97- // +---------------+--------------------------------------------------------------------------------------------+
98- ctx. sql ( "explain select * from animal a where async_equal(a.id, 1)" )
99- . await ?
100- . show ( )
101- . await ?;
102-
103- // +----+------+
104- // | id | name |
105- // +----+------+
106- // | 1 | cat |
107- // +----+------+
108- ctx. sql ( "select * from animal a where async_equal(a.id, 1)" )
109- . await ?
110- . show ( )
111- . await ?;
92+ assert_batches_eq ! (
93+ [
94+ "+---------------+--------------------------------------------------------------------------------------------------------------------------------+" ,
95+ "| plan_type | plan |" ,
96+ "+---------------+--------------------------------------------------------------------------------------------------------------------------------+" ,
97+ "| logical_plan | SubqueryAlias: a |" ,
98+ "| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\" Is this animal furry?\" )) |" ,
99+ "| | TableScan: animal projection=[id, name] |" ,
100+ "| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |" ,
101+ "| | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |" ,
102+ "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |" ,
103+ "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |" ,
104+ "| | CoalesceBatchesExec: target_batch_size=8192 |" ,
105+ "| | DataSourceExec: partitions=1, partition_sizes=[1] |" ,
106+ "| | |" ,
107+ "+---------------+--------------------------------------------------------------------------------------------------------------------------------+" ,
108+ ] ,
109+ & results
110+ ) ;
112111
113112 Ok ( ( ) )
114113}
115114
115+ /// Returns a sample `RecordBatch` representing an "animal" table with two columns:
116116fn animal ( ) -> Result < RecordBatch > {
117117 let schema = Arc :: new ( Schema :: new ( vec ! [
118118 Field :: new( "id" , DataType :: Int64 , false ) ,
@@ -127,118 +127,45 @@ fn animal() -> Result<RecordBatch> {
127127 Ok ( RecordBatch :: try_new ( schema, vec ! [ id_array, name_array] ) ?)
128128}
129129
130+ /// An async UDF that simulates asking a large language model (LLM) service a
131+ /// question based on the content of two columns. The UDF will return a boolean
132+ /// indicating whether the LLM thinks the first argument matches the question in
133+ /// the second argument.
134+ ///
135+ /// Since this is a simplified example, it does not call an LLM service, but
136+ /// could be extended to do so in a real-world scenario.
130137#[ derive( Debug ) ]
131- pub struct AsyncUpper {
132- signature : Signature ,
133- }
134-
135- impl Default for AsyncUpper {
136- fn default ( ) -> Self {
137- Self :: new ( )
138- }
139- }
140-
141- impl AsyncUpper {
142- pub fn new ( ) -> Self {
143- Self {
144- signature : Signature :: new (
145- TypeSignature :: Coercible ( vec ! [ Coercion :: Exact {
146- desired_type: TypeSignatureClass :: Native ( logical_string( ) ) ,
147- } ] ) ,
148- Volatility :: Volatile ,
149- ) ,
150- }
151- }
152- }
153-
154- #[ async_trait]
155- impl ScalarUDFImpl for AsyncUpper {
156- fn as_any ( & self ) -> & dyn Any {
157- self
158- }
159-
160- fn name ( & self ) -> & str {
161- "async_upper"
162- }
163-
164- fn signature ( & self ) -> & Signature {
165- & self . signature
166- }
167-
168- fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
169- Ok ( DataType :: Utf8 )
170- }
171-
172- fn invoke_with_args ( & self , _args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
173- not_impl_err ! ( "AsyncUpper can only be called from async contexts" )
174- }
175- }
176-
177- #[ async_trait]
178- impl AsyncScalarUDFImpl for AsyncUpper {
179- fn ideal_batch_size ( & self ) -> Option < usize > {
180- Some ( 10 )
181- }
182-
183- async fn invoke_async_with_args (
184- & self ,
185- args : ScalarFunctionArgs ,
186- _option : & ConfigOptions ,
187- ) -> Result < ArrayRef > {
188- trace ! ( "Invoking async_upper with args: {:?}" , args) ;
189- let value = & args. args [ 0 ] ;
190- let result = match value {
191- ColumnarValue :: Array ( array) => {
192- let string_array = array. as_string :: < i32 > ( ) ;
193- let iter = ArrayIter :: new ( string_array) ;
194- let result = iter
195- . map ( |string| string. map ( |s| s. to_uppercase ( ) ) )
196- . collect :: < StringArray > ( ) ;
197- Arc :: new ( result) as ArrayRef
198- }
199- _ => return internal_err ! ( "Expected a string argument, got {:?}" , value) ,
200- } ;
201- Ok ( result)
202- }
203- }
204-
205- #[ derive( Debug ) ]
206- struct AsyncEqual {
138+ struct AskLLM {
207139 signature : Signature ,
208140}
209141
210- impl Default for AsyncEqual {
142+ impl Default for AskLLM {
211143 fn default ( ) -> Self {
212144 Self :: new ( )
213145 }
214146}
215147
216- impl AsyncEqual {
148+ impl AskLLM {
217149 pub fn new ( ) -> Self {
218150 Self {
219- signature : Signature :: new (
220- TypeSignature :: Coercible ( vec ! [
221- Coercion :: Exact {
222- desired_type: TypeSignatureClass :: Native ( logical_int64( ) ) ,
223- } ,
224- Coercion :: Exact {
225- desired_type: TypeSignatureClass :: Native ( logical_int64( ) ) ,
226- } ,
227- ] ) ,
151+ signature : Signature :: exact (
152+ vec ! [ DataType :: Utf8View , DataType :: Utf8View ] ,
228153 Volatility :: Volatile ,
229154 ) ,
230155 }
231156 }
232157}
233158
234- #[ async_trait]
235- impl ScalarUDFImpl for AsyncEqual {
159+ /// All async UDFs implement the `ScalarUDFImpl` trait, which provides the basic
160+ /// information for the function, such as its name, signature, and return type.
161+ /// [async_trait]
162+ impl ScalarUDFImpl for AskLLM {
236163 fn as_any ( & self ) -> & dyn Any {
237164 self
238165 }
239166
240167 fn name ( & self ) -> & str {
241- "async_equal "
168+ "ask_llm "
242169 }
243170
244171 fn signature ( & self ) -> & Signature {
@@ -249,19 +176,64 @@ impl ScalarUDFImpl for AsyncEqual {
249176 Ok ( DataType :: Boolean )
250177 }
251178
179+ /// Since this is an async UDF, the `invoke_with_args` method will not be
180+ /// called directly.
252181 fn invoke_with_args ( & self , _args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
253- not_impl_err ! ( "AsyncEqual can only be called from async contexts" )
182+ not_impl_err ! ( "AskLLM can only be called from async contexts" )
254183 }
255184}
256185
186+ /// In addition to [`ScalarUDFImpl`], we also need to implement the
187+ /// [`AsyncScalarUDFImpl`] trait.
257188#[ async_trait]
258- impl AsyncScalarUDFImpl for AsyncEqual {
189+ impl AsyncScalarUDFImpl for AskLLM {
190+ /// The `invoke_async_with_args` method is similar to `invoke_with_args`,
191+ /// but it returns a `Future` that resolves to the result.
192+ ///
193+ /// Since this signature is `async`, it can do any `async` operations, such
194+ /// as network requests. This method is run on the same tokio `Runtime` that
195+ /// is processing the query, so you may wish to make actual network requests
196+ /// on a different `Runtime`, as explained in the `thread_pools.rs` example
197+ /// in this directory.
259198 async fn invoke_async_with_args (
260199 & self ,
261200 args : ScalarFunctionArgs ,
262201 _option : & ConfigOptions ,
263202 ) -> Result < ArrayRef > {
264- let [ arg1, arg2] = take_function_args ( self . name ( ) , & args. args ) ?;
265- apply_cmp ( arg1, arg2, eq) ?. to_array ( args. number_rows )
203+ // in a real UDF you would likely want to special case constant
204+ // arguments to improve performance, but this example converts the
205+ // arguments to arrays for simplicity.
206+ let args = ColumnarValue :: values_to_arrays ( & args. args ) ?;
207+ let [ content_column, question_column] = take_function_args ( self . name ( ) , args) ?;
208+
209+ // In a real function, you would use a library such as `reqwest` here to
210+ // make an async HTTP request. Credentials and other configurations can
211+ // be supplied via the `ConfigOptions` parameter.
212+
213+ // In this example, we will simulate the LLM response by comparing the two
214+ // input arguments using some static strings
215+ let content_column = as_string_view_array ( & content_column) ?;
216+ let question_column = as_string_view_array ( & question_column) ?;
217+
218+ let result_array: BooleanArray = content_column
219+ . iter ( )
220+ . zip ( question_column. iter ( ) )
221+ . map ( |( a, b) | {
222+ // If either value is null, return None
223+ let a = a?;
224+ let b = b?;
225+ // Simulate an LLM response by checking the arguments to some
226+ // hardcoded conditions.
227+ if a. contains ( "cat" ) && b. contains ( "furry" )
228+ || a. contains ( "dog" ) && b. contains ( "furry" )
229+ {
230+ Some ( true )
231+ } else {
232+ Some ( false )
233+ }
234+ } )
235+ . collect ( ) ;
236+
237+ Ok ( Arc :: new ( result_array) )
266238 }
267239}
0 commit comments