From 5d1d500f1821010201d66c57cb34c2a1e4e2975c Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 2 Mar 2024 10:15:29 +0800 Subject: [PATCH] change simplify signature Signed-off-by: jayzhan211 --- .../user_defined_scalar_functions.rs | 9 ++++++--- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/simplify.rs | 9 +++++++++ datafusion/expr/src/udf.rs | 19 +++++-------------- .../function_simplifier.rs | 11 +++-------- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index d4cd56ac6d5ce..43f4546c39b69 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -24,14 +24,16 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::as_float64_array; -use datafusion_common::DFSchemaRef; +use datafusion_common::DFSchema; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, plan_err, DataFusionError, ExprSchema, Result, ScalarValue, }; +use datafusion_expr::simplify::Simplified; +use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable, - LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Simplified, Volatility, + LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use rand::{thread_rng, Rng}; @@ -529,8 +531,9 @@ impl ScalarUDFImpl for CastToI64UDF { Ok(DataType::Int64) } // Wrap with Expr::Cast() to Int64 - fn simplify(&self, args: &[Expr], schema: DFSchemaRef) -> Result<Simplified> { + fn simplify(&self, args: &[Expr], info: &dyn SimplifyInfo) -> Result<Simplified> { let e = args[0].to_owned(); + let schema = info.schema().unwrap_or_else(|| DFSchema::empty().into()); let 
casted_expr = e.cast_to(&DataType::Int64, schema.as_ref())?; Ok(Simplified::Rewritten(casted_expr)) } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e1ea51dc0324b..a297f2dc78866 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -83,7 +83,7 @@ pub use signature::{ }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateUDF, AggregateUDFImpl}; -pub use udf::{ScalarUDF, ScalarUDFImpl, Simplified}; +pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 89b1623aa606a..3c34f9d74a5b6 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -140,3 +140,12 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { self.props } } + +/// Was the expression simplified? +pub enum Simplified { + /// The function call was simplified to an entirely new Expr + Rewritten(Expr), + /// the function call could not be simplified, and the arguments + /// are returned unmodified + Original, +} diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 5a7c03574360a..c06b81b263fe0 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,29 +17,20 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use crate::simplify::{Simplified, SimplifyInfo}; use crate::ExprSchemable; use crate::{ ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{DFSchemaRef, ExprSchema, Result}; +use datafusion_common::{ExprSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -// TODO(In this PR): Move to simplify.rs -/// Was the expression simplified? 
-pub enum Simplified { - /// The function call was simplified to an entirely new Expr - Rewritten(Expr), - /// the function call could not be simplified, and the arguments - /// are return unmodified - Original, -} - /// Logical representation of a Scalar User Defined Function. /// /// A scalar function produces a single row output for each row of input. This @@ -173,8 +164,8 @@ impl ScalarUDF { /// Do the function rewrite /// /// See [`ScalarUDFImpl::simplify`] for more details. - pub fn simplify(&self, args: &[Expr], schema: DFSchemaRef) -> Result<Simplified> { - self.inner.simplify(args, schema) + pub fn simplify(&self, args: &[Expr], info: &dyn SimplifyInfo) -> Result<Simplified> { + self.inner.simplify(args, info) } /// Invoke the function on `args`, returning the appropriate result. @@ -358,7 +349,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { // Do the function rewrite. // 'args': The arguments of the function // 'schema': The schema of the function - fn simplify(&self, _args: &[Expr], _schema: DFSchemaRef) -> Result<Simplified> { + fn simplify(&self, _args: &[Expr], _info: &dyn SimplifyInfo) -> Result<Simplified> { Ok(Simplified::Original) } } diff --git a/datafusion/optimizer/src/simplify_expressions/function_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/function_simplifier.rs index 6ff598035a75c..dc8de3b48dabf 100644 --- a/datafusion/optimizer/src/simplify_expressions/function_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/function_simplifier.rs @@ -18,9 +18,9 @@ //! This module implements a rule that do function simplification. 
use datafusion_common::tree_node::TreeNodeRewriter; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::Result; +use datafusion_expr::simplify::Simplified; use datafusion_expr::simplify::SimplifyInfo; -use datafusion_expr::Simplified; use datafusion_expr::{expr::ScalarFunction, Expr, ScalarFunctionDefinition}; pub(super) struct FunctionSimplifier<'a, S> { @@ -42,12 +42,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for FunctionSimplifier<'a, S> { args, }) = &expr { - let schema = self - .info - .schema() - .unwrap_or_else(|| DFSchema::empty().into()); - - let simplified_expr = udf.simplify(args, schema)?; + let simplified_expr = udf.simplify(args, self.info)?; match simplified_expr { Simplified::Original => Ok(expr), Simplified::Rewritten(expr) => Ok(expr),