Skip to content

Commit 6710e6d

Browse files
milenkovicm and alamb authored
Add example for FunctionFactory (#9482)
* `FunctionFactory` usage example * update test to use the same function factory * Add entry to examples/README.md * Add SessionContext::with_function_factory * Update doc and example * clippy --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent afddb32 commit 6710e6d

File tree

4 files changed

+408
-62
lines changed

4 files changed

+408
-62
lines changed

datafusion-examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ cargo run --example csv_sql
5454
- [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde
5555
- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and analyze `Expr`s
5656
- [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients
57+
- [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros
5758
- [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function
5859
- [`memtable.rs`](examples/memtable.rs): Create and query data in memory using SQL and `RecordBatch`es
5960
- [`pruning.rs`](examples/pruning.rs): Use pruning to rule out files based on statistics
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::error::Result;
19+
use datafusion::execution::config::SessionConfig;
20+
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionContext};
21+
use datafusion_common::tree_node::{Transformed, TreeNode};
22+
use datafusion_common::{exec_err, internal_err, DataFusionError};
23+
use datafusion_expr::simplify::ExprSimplifyResult;
24+
use datafusion_expr::simplify::SimplifyInfo;
25+
use datafusion_expr::{CreateFunction, Expr, ScalarUDF, ScalarUDFImpl, Signature};
26+
use std::result::Result as RResult;
27+
use std::sync::Arc;
28+
29+
/// This example shows how to utilize [FunctionFactory] to implement simple
30+
/// SQL-macro like functions using a `CREATE FUNCTION` statement. The same
31+
/// functionality can support functions defined in any language or library.
32+
///
33+
/// Apart from [FunctionFactory], this example covers
34+
/// [ScalarUDFImpl::simplify()] which is often used at the same time, to replace
35+
/// a function call with another expression at runtime.
36+
///
37+
/// This example is rather simple and does not cover all cases required for a
38+
/// real implementation.
39+
#[tokio::main]
40+
async fn main() -> Result<()> {
41+
// First we must configure the SessionContext with our function factory
42+
let ctx = SessionContext::new()
43+
// register custom function factory
44+
.with_function_factory(Arc::new(CustomFunctionFactory::default()));
45+
46+
// With the function factory, we can now call `CREATE FUNCTION` SQL functions
47+
48+
// Let us register a function called f1 which takes a single argument and
49+
// returns that value plus one
50+
let sql = r#"
51+
CREATE FUNCTION f1(BIGINT)
52+
RETURNS BIGINT
53+
RETURN $1 + 1
54+
"#;
55+
56+
ctx.sql(sql).await?.show().await?;
57+
58+
// Now, let us register a function called f2 which takes two arguments and
59+
// returns the first argument added to the result of calling f1 on the second
60+
// argument
61+
let sql = r#"
62+
CREATE FUNCTION f2(BIGINT, BIGINT)
63+
RETURNS BIGINT
64+
RETURN $1 + f1($2)
65+
"#;
66+
67+
ctx.sql(sql).await?.show().await?;
68+
69+
// Invoke f2, and we expect to see 1 + (1 + 2) = 4
70+
// Note this function works on columns as well as constants.
71+
let sql = r#"
72+
SELECT f2(1, 2)
73+
"#;
74+
ctx.sql(sql).await?.show().await?;
75+
76+
// Now we clean up the session by dropping the functions
77+
ctx.sql("DROP FUNCTION f1").await?.show().await?;
78+
ctx.sql("DROP FUNCTION f2").await?.show().await?;
79+
80+
Ok(())
81+
}
82+
83+
/// This is our FunctionFactory that is responsible for converting `CREATE
84+
/// FUNCTION` statements into function instances
85+
#[derive(Debug, Default)]
86+
struct CustomFunctionFactory {}
87+
88+
#[async_trait::async_trait]
89+
impl FunctionFactory for CustomFunctionFactory {
90+
/// This function takes the parsed `CREATE FUNCTION` statement and returns
91+
/// the function instance.
92+
async fn create(
93+
&self,
94+
_state: &SessionConfig,
95+
statement: CreateFunction,
96+
) -> Result<RegisterFunction> {
97+
let f: ScalarFunctionWrapper = statement.try_into()?;
98+
99+
Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(f))))
100+
}
101+
}
102+
103+
/// Scalar UDF created from a parsed `CREATE FUNCTION` statement.
///
/// Calls to this function are rewritten by `simplify` into `expr`, with the
/// placeholders replaced by the call arguments, so `invoke` is not expected
/// to be called.
104+
#[derive(Debug)]
105+
struct ScalarFunctionWrapper {
106+
/// The name of the function, taken from the `CREATE FUNCTION` statement
107+
name: String,
108+
/// The function body expression, `$1 + f1($2)` in our example
expr: Expr,
109+
/// Exact signature built from the declared argument types and volatility
signature: Signature,
110+
/// The declared return type of the function
return_type: arrow_schema::DataType,
111+
}
112+
113+
impl ScalarUDFImpl for ScalarFunctionWrapper {
114+
fn as_any(&self) -> &dyn std::any::Any {
115+
self
116+
}
117+
118+
fn name(&self) -> &str {
119+
&self.name
120+
}
121+
122+
fn signature(&self) -> &datafusion_expr::Signature {
123+
&self.signature
124+
}
125+
126+
fn return_type(
127+
&self,
128+
_arg_types: &[arrow_schema::DataType],
129+
) -> Result<arrow_schema::DataType> {
130+
Ok(self.return_type.clone())
131+
}
132+
133+
fn invoke(
134+
&self,
135+
_args: &[datafusion_expr::ColumnarValue],
136+
) -> Result<datafusion_expr::ColumnarValue> {
137+
// Since this function is always simplified to another expression, it
138+
// should never actually be invoked
139+
internal_err!("This function should not get invoked!")
140+
}
141+
142+
/// The simplify function is called to simplify a call such as `f2(1, 2)`. It
143+
/// replaces the placeholders in the function body with the call arguments and returns the resulting expression
144+
fn simplify(
145+
&self,
146+
args: Vec<Expr>,
147+
_info: &dyn SimplifyInfo,
148+
) -> Result<ExprSimplifyResult> {
149+
let replacement = Self::replacement(&self.expr, &args)?;
150+
151+
Ok(ExprSimplifyResult::Simplified(replacement))
152+
}
153+
154+
fn aliases(&self) -> &[String] {
155+
&[]
156+
}
157+
158+
fn monotonicity(&self) -> Result<Option<datafusion_expr::FuncMonotonicity>> {
159+
Ok(None)
160+
}
161+
}
162+
163+
impl ScalarFunctionWrapper {
164+
// replaces placeholders such as $1 with actual arguments (args[0]
165+
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
166+
let result = expr.clone().transform(&|e| {
167+
let r = match e {
168+
Expr::Placeholder(placeholder) => {
169+
let placeholder_position =
170+
Self::parse_placeholder_identifier(&placeholder.id)?;
171+
if placeholder_position < args.len() {
172+
Transformed::yes(args[placeholder_position].clone())
173+
} else {
174+
exec_err!(
175+
"Function argument {} not provided, argument missing!",
176+
placeholder.id
177+
)?
178+
}
179+
}
180+
_ => Transformed::no(e),
181+
};
182+
183+
Ok(r)
184+
})?;
185+
186+
Ok(result.data)
187+
}
188+
// Finds placeholder identifier such as `$X` format where X >= 1
189+
fn parse_placeholder_identifier(placeholder: &str) -> Result<usize> {
190+
if let Some(value) = placeholder.strip_prefix('$') {
191+
Ok(value.parse().map(|v: usize| v - 1).map_err(|e| {
192+
DataFusionError::Execution(format!(
193+
"Placeholder `{}` parsing error: {}!",
194+
placeholder, e
195+
))
196+
})?)
197+
} else {
198+
exec_err!("Placeholder should start with `$`!")
199+
}
200+
}
201+
}
202+
203+
/// This impl block creates a scalar function from
204+
/// a parsed `CREATE FUNCTION` statement (`CreateFunction`)
205+
impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
206+
type Error = DataFusionError;
207+
208+
fn try_from(definition: CreateFunction) -> RResult<Self, Self::Error> {
209+
Ok(Self {
210+
name: definition.name,
211+
expr: definition
212+
.params
213+
.return_
214+
.expect("Expression has to be defined!"),
215+
return_type: definition
216+
.return_type
217+
.expect("Return type has to be defined!"),
218+
signature: Signature::exact(
219+
definition
220+
.args
221+
.unwrap_or_default()
222+
.into_iter()
223+
.map(|a| a.data_type)
224+
.collect(),
225+
definition
226+
.params
227+
.behavior
228+
.unwrap_or(datafusion_expr::Volatility::Volatile),
229+
),
230+
})
231+
}
232+
}

datafusion/core/src/execution/context/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,15 @@ impl SessionContext {
349349
self.session_start_time
350350
}
351351

352+
/// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements
///
/// Takes `self` by value and returns it, so the call can be chained
/// directly after constructing the context.
353+
pub fn with_function_factory(
354+
self,
355+
function_factory: Arc<dyn FunctionFactory>,
356+
) -> Self {
357+
// Install the factory on the session state (accessed via `state.write()`)
self.state.write().set_function_factory(function_factory);
358+
self
359+
}
360+
352361
/// Registers the [`RecordBatch`] as the specified table name
353362
pub fn register_batch(
354363
&self,
@@ -1659,6 +1668,11 @@ impl SessionState {
16591668
self
16601669
}
16611670

1671+
/// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements
///
/// Any previously registered factory is replaced.
1672+
pub fn set_function_factory(&mut self, function_factory: Arc<dyn FunctionFactory>) {
1673+
self.function_factory = Some(function_factory);
1674+
}
1675+
16621676
/// Replace the extension [`SerializerRegistry`]
16631677
pub fn with_serializer_registry(
16641678
mut self,

0 commit comments

Comments
 (0)