Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a ScalarUDFImpl::simplify() API, move SimplifyInfo et al to datafusion_expr #9304

Merged
merged 30 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
b787dfb
first draft
jayzhan211 Feb 21, 2024
7259275
clippy
jayzhan211 Feb 21, 2024
4d98121
add comments
jayzhan211 Feb 21, 2024
bacc966
move to optimize rule
jayzhan211 Feb 23, 2024
3199bca
cleanup
jayzhan211 Feb 23, 2024
63648be
fix explain test
jayzhan211 Feb 23, 2024
83fc9d8
move to simplifier
jayzhan211 Feb 27, 2024
c73126a
pass with schema
jayzhan211 Feb 28, 2024
dd99362
fix explain
jayzhan211 Feb 28, 2024
0b66ed3
fix doc
jayzhan211 Feb 28, 2024
0798274
move to expr
jayzhan211 Mar 2, 2024
5fdf177
change simplify signature
jayzhan211 Mar 2, 2024
ab66a19
cleanup
jayzhan211 Mar 2, 2024
7c7b654
cleanup
jayzhan211 Mar 2, 2024
cbefb3c
fix doc
jayzhan211 Mar 2, 2024
7718251
fix doc
jayzhan211 Mar 2, 2024
5cdae92
Update datafusion/expr/src/udf.rs
alamb Mar 2, 2024
4886ba5
Add backwards compatible uses, inline FunctionSimplifier, rename to …
alamb Mar 2, 2024
ed6a04b
Remove DFSchema from SimplifyInfo
alamb Mar 2, 2024
f6848d8
Avoid requiring argument copies
alamb Mar 2, 2024
a8541ff
Merge remote-tracking branch 'apache/main' into simply-udf
alamb Mar 2, 2024
18f8371
Improve docs
alamb Mar 2, 2024
bfb54a0
fix link
alamb Mar 2, 2024
33aa7ff
fix doc test
alamb Mar 2, 2024
24adcbf
Update datafusion/physical-expr/src/lib.rs
alamb Mar 3, 2024
8cab80b
Merge remote-tracking branch 'apache/main' into simply-udf
alamb Mar 3, 2024
550cbc4
Merge remote-tracking branch 'apache/main' into simply-udf
alamb Mar 4, 2024
fea82cb
Change example simplify to always simplify its argument
alamb Mar 4, 2024
4e9eb70
Clarify comment
alamb Mar 5, 2024
fdec54c
Merge remote-tracking branch 'apache/main' into simply-udf
alamb Mar 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ use arrow::record_batch::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::{DFField, DFSchema};
use datafusion::error::Result;
use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use datafusion::physical_expr::execution_props::ExecutionProps;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::physical_expr::{
analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr,
};
use datafusion::prelude::*;
use datafusion_common::{ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};

/// This example demonstrates the DataFusion [`Expr`] API.
Expand Down
3 changes: 2 additions & 1 deletion datafusion-examples/examples/simple_udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::{plan_err, ScalarValue};
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{Expr, TableType};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use std::fs::File;
use std::io::Seek;
use std::path::Path;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ use crate::datasource::listing::ListingTableUrl;
use crate::execution::context::SessionState;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};

Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/src/datasource/physical_plan/parquet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,13 +800,14 @@ mod tests {
ArrayRef, Date64Array, Int32Array, Int64Array, Int8Array, StringArray,
StructArray,
};

use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder};
use arrow::record_batch::RecordBatch;
use arrow_schema::Fields;
use datafusion_common::{assert_contains, FileType, GetExt, ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{col, lit, when, Expr};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;

use chrono::{TimeZone, Utc};
use futures::StreamExt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,9 @@ mod test {
use super::*;
use arrow::datatypes::Field;
use datafusion_common::ToDFSchema;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{cast, col, lit, Expr};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use parquet::arrow::parquet_to_arrow_schema;
use parquet::file::reader::{FileReader, SerializedFileReader};
use rand::prelude::*;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ mod tests {
use arrow::datatypes::Schema;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{Result, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{cast, col, lit, Expr};
use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use parquet::arrow::arrow_to_parquet_schema;
use parquet::arrow::async_reader::ParquetObjectReader;
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ use datafusion_common::{
tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion},
};
use datafusion_execution::registry::SerializerRegistry;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::var_provider::is_system_variables;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
use parking_lot::RwLock;
use std::collections::hash_map::Entry;
use std::string::String;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1338,10 +1338,10 @@ mod tests {
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use std::collections::HashMap;
use std::ops::{Not, Rem};

Expand Down
5 changes: 3 additions & 2 deletions datafusion/core/src/test_util/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ use crate::datasource::listing::{ListingTableUrl, PartitionedFile};
use crate::datasource::object_store::ObjectStoreUrl;
use crate::datasource::physical_plan::{FileScanConfig, ParquetExec};
use crate::error::Result;
use crate::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use crate::logical_expr::execution_props::ExecutionProps;
use crate::logical_expr::simplify::SimplifyContext;
use crate::optimizer::simplify_expressions::ExprSimplifier;
use crate::physical_expr::create_physical_expr;
use crate::physical_expr::execution_props::ExecutionProps;
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::metrics::MetricsSet;
use crate::physical_plan::ExecutionPlan;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/variable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@

//! Variable provider for `@name` and `@@name` style runtime values.

pub use datafusion_physical_expr::var_provider::{VarProvider, VarType};
pub use datafusion_expr::var_provider::{VarProvider, VarType};
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOpt
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr,
ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion_physical_expr::var_provider::{VarProvider, VarType};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/parquet/page_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ use datafusion::physical_plan::metrics::MetricValue;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::{ScalarValue, Statistics, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{col, lit, Expr};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;

use futures::StreamExt;
use object_store::path::Path;
Expand Down
26 changes: 14 additions & 12 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@
use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::{ArrayRef, Int32Array};
use chrono::{DateTime, TimeZone, Utc};
use datafusion::common::DFSchema;
use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*};
use datafusion_common::cast::as_int32_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DFSchemaRef, ToDFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility,
};
use datafusion_optimizer::simplify_expressions::{
ExprSimplifier, SimplifyExpressions, SimplifyInfo,
};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use std::sync::Arc;

Expand All @@ -42,7 +41,7 @@ use std::sync::Arc;
/// objects or from some other implementation
struct MyInfo {
/// The input schema
schema: DFSchema,
schema: DFSchemaRef,

/// Execution specific details needed for constant evaluation such
/// as the current time for `now()` and [VariableProviders]
Expand All @@ -51,24 +50,27 @@ struct MyInfo {

impl SimplifyInfo for MyInfo {
fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
Ok(matches!(expr.get_type(&self.schema)?, DataType::Boolean))
Ok(matches!(
expr.get_type(self.schema.as_ref())?,
DataType::Boolean
))
}

fn nullable(&self, expr: &Expr) -> Result<bool> {
expr.nullable(&self.schema)
expr.nullable(self.schema.as_ref())
}

fn execution_props(&self) -> &ExecutionProps {
&self.execution_props
}

fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
expr.get_type(&self.schema)
expr.get_type(self.schema.as_ref())
}
}

impl From<DFSchema> for MyInfo {
fn from(schema: DFSchema) -> Self {
impl From<DFSchemaRef> for MyInfo {
fn from(schema: DFSchemaRef) -> Self {
Self {
schema,
execution_props: ExecutionProps::new(),
Expand All @@ -81,13 +83,13 @@ impl From<DFSchema> for MyInfo {
/// a: Int32 (possibly with nulls)
/// b: Int32
/// s: Utf8
fn schema() -> DFSchema {
fn schema() -> DFSchemaRef {
Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, false),
Field::new("s", DataType::Utf8, false),
])
.try_into()
.to_dfschema_ref()
.unwrap()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
Expand All @@ -26,10 +28,13 @@ use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
plan_err, ExprSchema, Result, ScalarValue,
};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};

use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
Expand Down Expand Up @@ -514,6 +519,97 @@ async fn deregister_udf() -> Result<()> {
Ok(())
}

/// Example scalar UDF named `cast_to_i64`. Its `simplify` implementation
/// rewrites a call on a non-`Int64` argument into an `Expr::Cast` to `Int64`.
#[derive(Debug)]
struct CastToI64UDF {
/// Signature handed back by `ScalarUDFImpl::signature`.
signature: Signature,
}

impl CastToI64UDF {
    /// Creates the UDF with a `Signature::any(1, ..)` signature (a single
    /// argument) and `Volatility::Immutable`.
    fn new() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for CastToI64UDF {
/// Returns `self` as a `&dyn Any`.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns the function name, `cast_to_i64`.
fn name(&self) -> &str {
"cast_to_i64"
}
/// Returns a reference to the signature stored in this UDF.
fn signature(&self) -> &Signature {
&self.signature
}
/// Always reports `Int64` as the return type; the argument types are ignored.
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
Ok(DataType::Int64)
}
// Wrap with Expr::Cast() to Int64
fn simplify(
&self,
mut args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
// Note that Expr::cast_to requires an ExprSchema but simplify gets a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't figure out how to do cast_to but I think this way is OK too.

// SimplifyInfo so we have to replicate some of the casting logic here.
let source_type = info.get_data_type(&args[0])?;
if source_type == DataType::Int64 {
Ok(ExprSimplifyResult::Original(args))
} else {
// DataFusion should have ensured the function is called with just a
// single argument
assert_eq!(args.len(), 1);
let e = args.pop().unwrap();
Ok(ExprSimplifyResult::Simplified(Expr::Cast(
datafusion_expr::Cast {
expr: Box::new(e),
data_type: DataType::Int64,
},
)))
}
}
// Casting should be done in `simplify`, so we just return the first argument
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment is a bit outdated

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
assert_eq!(args.len(), 1);
Ok(args.first().unwrap().clone())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to add ExprSimplifyResult::Replace(Expr) to cover this case, and eliminate the UDF call when the expression is simplified?

Copy link
Contributor

@milenkovicm milenkovicm Mar 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To give some context for my comment: if we define a function f(INT, INT) = $1 + $2, we can replace the UDF call with Alias($1 + $2, "f(a,b)") and get a UDF-free plan, which would be easier to distribute across a Ballista cluster.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is an excellent idea -- I did so in fea82cb

}
}

#[tokio::test]
async fn test_user_defined_functions_cast_to_i64() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for this test / example -- it makes seeing how the API would work really clear. 👏

let ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, false)]));

let batch = RecordBatch::try_new(
schema,
vec![Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]))],
)?;

ctx.register_batch("t", batch)?;

let cast_to_i64_udf = ScalarUDF::from(CastToI64UDF::new());
ctx.register_udf(cast_to_i64_udf);

let result = plan_and_collect(&ctx, "SELECT cast_to_i64(x) FROM t").await?;

assert_batches_eq!(
&[
"+------------------+",
"| cast_to_i64(t.x) |",
"+------------------+",
"| 1 |",
"| 2 |",
"| 3 |",
"+------------------+"
],
&result
);

Ok(())
}

#[derive(Debug)]
struct TakeUDF {
signature: Signature,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ ahash = { version = "0.8", default-features = false, features = [
] }
arrow = { workspace = true }
arrow-array = { workspace = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
paste = "^1.0"
sqlparser = { workspace = true }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,12 @@ use std::sync::Arc;
/// Holds per-query execution properties and data (such as statement
/// starting timestamps).
///
/// An [`ExecutionProps`] is created each time a [`LogicalPlan`] is
/// An [`ExecutionProps`] is created each time a `LogicalPlan` is
/// prepared for execution (optimized). If the same plan is optimized
/// multiple times, a new `ExecutionProps` is created each time.
///
/// It is important that this structure be cheap to create as it is
/// done so during predicate pruning and expression simplification
///
/// [`LogicalPlan`]: datafusion_expr::LogicalPlan
#[derive(Clone, Debug)]
pub struct ExecutionProps {
pub query_execution_start_time: DateTime<Utc>,
Expand Down
Loading
Loading