-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a ScalarUDFImpl::simplfy()
API, move SimplifyInfo
et al to datafusion_expr
#9304
Changes from 24 commits
b787dfb
7259275
4d98121
bacc966
3199bca
63648be
83fc9d8
c73126a
dd99362
0b66ed3
0798274
5fdf177
ab66a19
7c7b654
cbefb3c
7718251
5cdae92
4886ba5
ed6a04b
f6848d8
a8541ff
18f8371
bfb54a0
33aa7ff
24adcbf
8cab80b
550cbc4
fea82cb
4e9eb70
fdec54c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,9 @@ | |
// under the License. | ||
|
||
use arrow::compute::kernels::numeric::add; | ||
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array}; | ||
use arrow_array::{ | ||
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, | ||
}; | ||
use arrow_schema::DataType::Float64; | ||
use arrow_schema::{DataType, Field, Schema}; | ||
use datafusion::prelude::*; | ||
|
@@ -26,10 +28,13 @@ use datafusion_common::{ | |
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, | ||
plan_err, ExprSchema, Result, ScalarValue, | ||
}; | ||
use datafusion_expr::simplify::ExprSimplifyResult; | ||
use datafusion_expr::simplify::SimplifyInfo; | ||
use datafusion_expr::{ | ||
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable, | ||
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, | ||
}; | ||
|
||
use rand::{thread_rng, Rng}; | ||
use std::any::Any; | ||
use std::iter; | ||
|
@@ -514,6 +519,97 @@ async fn deregister_udf() -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
struct CastToI64UDF { | ||
signature: Signature, | ||
} | ||
|
||
impl CastToI64UDF { | ||
fn new() -> Self { | ||
Self { | ||
signature: Signature::any(1, Volatility::Immutable), | ||
} | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for CastToI64UDF { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
fn name(&self) -> &str { | ||
"cast_to_i64" | ||
} | ||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
fn return_type(&self, _args: &[DataType]) -> Result<DataType> { | ||
Ok(DataType::Int64) | ||
} | ||
// Wrap with Expr::Cast() to Int64 | ||
fn simplify( | ||
&self, | ||
mut args: Vec<Expr>, | ||
info: &dyn SimplifyInfo, | ||
) -> Result<ExprSimplifyResult> { | ||
// Note that Expr::cast_to requires an ExprSchema but simplify gets a | ||
// SimplifyInfo so we have to replicate some of the casting logic here. | ||
let source_type = info.get_data_type(&args[0])?; | ||
if source_type == DataType::Int64 { | ||
Ok(ExprSimplifyResult::Original(args)) | ||
} else { | ||
// DataFusion should have ensured the function is called with just a | ||
// single argument | ||
assert_eq!(args.len(), 1); | ||
let e = args.pop().unwrap(); | ||
Ok(ExprSimplifyResult::Simplified(Expr::Cast( | ||
datafusion_expr::Cast { | ||
expr: Box::new(e), | ||
data_type: DataType::Int64, | ||
}, | ||
))) | ||
} | ||
} | ||
// Casting should be done in `simplify`, so we just return the first argument | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. comment is a bit outdated |
||
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
assert_eq!(args.len(), 1); | ||
Ok(args.first().unwrap().clone()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To put some context to my comment, let's say if we define function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is an excellent idea -- I did so in fea82cb |
||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_user_defined_functions_cast_to_i64() -> Result<()> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you so much for this test / example -- it makes seeing how the API would work really clear. 👏 |
||
let ctx = SessionContext::new(); | ||
|
||
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, false)])); | ||
|
||
let batch = RecordBatch::try_new( | ||
schema, | ||
vec![Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]))], | ||
)?; | ||
|
||
ctx.register_batch("t", batch)?; | ||
|
||
let cast_to_i64_udf = ScalarUDF::from(CastToI64UDF::new()); | ||
ctx.register_udf(cast_to_i64_udf); | ||
|
||
let result = plan_and_collect(&ctx, "SELECT cast_to_i64(x) FROM t").await?; | ||
|
||
assert_batches_eq!( | ||
&[ | ||
"+------------------+", | ||
"| cast_to_i64(t.x) |", | ||
"+------------------+", | ||
"| 1 |", | ||
"| 2 |", | ||
"| 3 |", | ||
"+------------------+" | ||
], | ||
&result | ||
); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
struct TakeUDF { | ||
signature: Signature, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't figure out how to do
cast_to
but I think this way is OK too.