diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index 6ebf88a0b671..d530b9abe030 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -40,6 +40,7 @@ use std::sync::Arc; /// the power of the second argument `a^b`. /// /// To do so, we must implement the `ScalarUDFImpl` trait. +#[derive(Debug, Clone)] struct PowUdf { signature: Signature, aliases: Vec, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ebf4d3143c12..5617d217eb9f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1948,6 +1948,7 @@ mod test { ); // UDF + #[derive(Debug)] struct TestScalarUDF { signature: Signature, } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 7b3f65248586..0491750d18a9 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -984,6 +984,16 @@ pub struct SimpleScalarUDF { fun: ScalarFunctionImplementation, } +impl Debug for SimpleScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + impl SimpleScalarUDF { /// Create a new `SimpleScalarUDF` from a name, input types, return type and /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 2ec80a4a9ea1..8b35d5834c61 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -35,48 +35,26 @@ use std::sync::Arc; /// functions you supply such name, type signature, return type, and actual /// implementation. /// -/// /// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`]. /// /// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`]. 
/// +/// # API Note +/// +/// This is a separate struct from `ScalarUDFImpl` to maintain backwards +/// compatibility with the older API. +/// /// [`create_udf`]: crate::expr_fn::create_udf /// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs /// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct ScalarUDF { - /// The name of the function - name: String, - /// The signature (the types of arguments that are supported) - signature: Signature, - /// Function that returns the return type given the argument types - return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). - fun: ScalarFunctionImplementation, - /// Optional aliases for the function. 
This list should NOT include the value of `name` as well - aliases: Vec, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } + inner: Arc, } impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -84,8 +62,8 @@ impl Eq for ScalarUDF {} impl std::hash::Hash for ScalarUDF { fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } @@ -101,13 +79,12 @@ impl ScalarUDF { return_type: &ReturnTypeFunction, fun: &ScalarFunctionImplementation, ) -> Self { - Self { + Self::new_from_impl(ScalarUdfLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), fun: fun.clone(), - aliases: vec![], - } + }) } /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object @@ -115,37 +92,24 @@ impl ScalarUDF { /// Note this is the same as using the `From` impl (`ScalarUDF::from`) pub fn new_from_impl(fun: F) -> ScalarUDF where - F: ScalarUDFImpl + Send + Sync + 'static, + F: ScalarUDFImpl + 'static, { - // TODO change the internal implementation to use the trait object - let arc_fun = Arc::new(fun); - let captured_self = arc_fun.clone(); - let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { - let return_type = captured_self.return_type(arg_types)?; - Ok(Arc::new(return_type)) - }); - - let captured_self = arc_fun.clone(); - let func: ScalarFunctionImplementation = - Arc::new(move |args| captured_self.invoke(args)); - Self { - name: arc_fun.name().to_string(), - signature: arc_fun.signature().clone(), - return_type: return_type.clone(), - fun: func, - aliases: arc_fun.aliases().to_vec(), + inner: Arc::new(fun), } } - 
/// Adds additional names that can be used to invoke this function, in addition to `name` - pub fn with_aliases( - mut self, - aliases: impl IntoIterator, - ) -> Self { - self.aliases - .extend(aliases.into_iter().map(|s| s.to_string())); - self + /// Return the underlying [`ScalarUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc { + self.inner.clone() + } + + /// Adds additional names that can be used to invoke this function, in + /// addition to `name` + /// + /// If you implement [`ScalarUDFImpl`] yourself, return the aliases from [`ScalarUDFImpl::aliases`] instead. + pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { + Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified @@ -159,31 +123,46 @@ impl ScalarUDF { )) } - /// Returns this function's name + /// Returns this function's name. + /// + /// See [`ScalarUDFImpl::name`] for more details. pub fn name(&self) -> &str { - &self.name + self.inner.name() } - /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details + /// Returns the aliases for this function. + /// + /// See [`ScalarUDF::with_aliases`] for more details pub fn aliases(&self) -> &[String] { - &self.aliases + self.inner.aliases() } - /// Returns this function's [`Signature`] (what input types are accepted) + /// Returns this function's [`Signature`] (what input types are accepted). + /// + /// See [`ScalarUDFImpl::signature`] for more details. pub fn signature(&self) -> &Signature { - &self.signature + self.inner.signature() } - /// The datatype this function returns given the input argument input types + /// The datatype this function returns given the types of its input arguments. + /// + /// See [`ScalarUDFImpl::return_type`] for more details. 
pub fn return_type(&self, args: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(args)?; - Ok(res.as_ref().clone()) + self.inner.return_type(args) + } + + /// Invoke the function on `args`, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke`] for more details. + pub fn invoke(&self, args: &[ColumnarValue]) -> Result { + self.inner.invoke(args) } - /// Return an [`Arc`] to the function implementation + /// Returns a `ScalarFunctionImplementation` that can invoke the function + /// during execution pub fn fun(&self) -> ScalarFunctionImplementation { - self.fun.clone() + let captured = self.inner.clone(); + Arc::new(move |args| captured.invoke(args)) } } @@ -213,6 +192,7 @@ where /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// #[derive(Debug)] /// struct AddOne { /// signature: Signature /// }; @@ -246,7 +226,7 @@ where /// // Call the function `add_one(col)` /// let expr = add_one.call(vec![col("a")]); /// ``` -pub trait ScalarUDFImpl { +pub trait ScalarUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -292,3 +272,106 @@ pub trait ScalarUDFImpl { &[] } } + +/// ScalarUDF that adds an alias to the underlying function. It is better to +/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. 
+#[derive(Debug)] +struct AliasedScalarUDFImpl { + inner: ScalarUDF, + aliases: Vec, +} + +impl AliasedScalarUDFImpl { + pub fn new( + inner: ScalarUDF, + new_aliases: impl IntoIterator, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + + Self { inner, aliases } + } +} + +impl ScalarUDFImpl for AliasedScalarUDFImpl { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + self.inner.invoke(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers +/// of the older API (see +/// for more details) +struct ScalarUdfLegacyWrapper { + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, + /// actual implementation + /// + /// The fn param is the wrapped function but be aware that the function will + /// be passed with the slice / vec of columnar values (either scalar or array) + /// with the exception of zero param function, where a singular element vec + /// will be passed. In that case the single element is a null array to indicate + /// the batch's row count (so that the generative zero-argument function can know + /// the result array size). 
+ fun: ScalarFunctionImplementation, +} + +impl Debug for ScalarUdfLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl ScalarUDFImpl for ScalarUdfLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + (self.fun)(args) + } + + fn aliases(&self) -> &[String] { + &[] + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4d54dad99670..6f1da5f4e6d9 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -811,6 +811,7 @@ mod test { static TEST_SIGNATURE: OnceLock = OnceLock::new(); + #[derive(Debug, Clone, Default)] struct TestScalarUDF {} impl ScalarUDFImpl for TestScalarUDF { fn as_any(&self) -> &dyn Any { diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 0ec1cf3f256b..9daa9eb173dd 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -36,7 +36,7 @@ pub fn create_physical_expr( Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), - fun.fun().clone(), + fun.fun(), input_phy_exprs.to_vec(), fun.return_type(&input_exprs_types)?, None,