
Commit 92d9274

yyy1000, alamb, Weijun-H, and Jefffrey authored
Support computing return types from argument values (not just their DataTypes) (apache#8985)
* ScalarValue return types from argument values
* change file name
* try using ?Sized
* use Ok
* move method default impl outside trait
* Use type trait for ExprSchemable
* fix nit
* Proposed Return Type from Expr suggestions (#1)
* Improve return_type_from_args
* Rework example
* Update datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
---------
Co-authored-by: Junhao Liu <junhaoliu2023@gmail.com>
* Apply suggestions from code review
Co-authored-by: Alex Huang <huangweijun1001@gmail.com>
* Fix tests + clippy
* rework types to use dyn trait
* fmt
* docs
* Apply suggestions from code review
Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
* Add docs explaining what happens when both `return_type` and `return_type_from_exprs` are called
* clippy
* fix doc -- comedy of errors
---------
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
Co-authored-by: Alex Huang <huangweijun1001@gmail.com>
Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
1 parent 85be1bc commit 92d9274
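For orientation (this sketch is not part of the commit), here is a minimal, hypothetical UDF written against the `ScalarUDFImpl` methods visible in the diffs below. The names `ZeroOfUDF` and `zero_of` are invented for illustration; the point is that the new `return_type_from_exprs` hook receives the argument `Expr`s and the schema, so a function's return type can depend on an argument's literal value rather than only on its `DataType`.

use std::any::Any;

use arrow_schema::DataType;
use datafusion_common::{plan_err, DataFusionError, ExprSchema, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility};

/// Hypothetical UDF for illustration only: `zero_of('int')` evaluates to an
/// Int64 zero and `zero_of('float')` to a Float64 zero, so the return *type*
/// is decided by the argument's literal value, not just its DataType (which
/// is Utf8 either way).
#[derive(Debug)]
struct ZeroOfUDF {
    signature: Signature,
}

impl ZeroOfUDF {
    fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for ZeroOfUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "zero_of"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    // Plain DataType-based variant, kept as a simple fallback in this sketch.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }
    // The hook this commit introduces: it sees the argument expressions and
    // the schema, so a literal argument value can choose the output type.
    fn return_type_from_exprs(
        &self,
        args: &[Expr],
        _schema: &dyn ExprSchema,
    ) -> Result<DataType> {
        match args.first() {
            Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) if s.as_str() == "int" => {
                Ok(DataType::Int64)
            }
            Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) if s.as_str() == "float" => {
                Ok(DataType::Float64)
            }
            other => plan_err!("zero_of expects the literal 'int' or 'float', got {other:?}"),
        }
    }
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        // Produce a scalar zero of the type promised above.
        match &args[0] {
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) if s.as_str() == "float" => {
                Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(0.0))))
            }
            _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(0)))),
        }
    }
}

As with the `TakeUDF` test this commit adds, such a sketch would be wired up via `ScalarUDF::from(ZeroOfUDF::new())` and then registered or invoked with `call(...)`; which of `return_type` and `return_type_from_exprs` ends up being consulted is what the commit's added documentation describes.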

File tree

6 files changed (+245 -53 lines changed)


datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 139 additions & 3 deletions
@@ -22,12 +22,16 @@ use arrow_schema::{DataType, Field, Schema};
 use datafusion::prelude::*;
 use datafusion::{execution::registry::FunctionRegistry, test_util};
 use datafusion_common::cast::as_float64_array;
-use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
+use datafusion_common::{
+    assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
+    plan_err, DataFusionError, ExprSchema, Result, ScalarValue,
+};
 use datafusion_expr::{
-    create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF,
-    ScalarUDFImpl, Signature, Volatility,
+    create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
+    LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
 };
 use rand::{thread_rng, Rng};
+use std::any::Any;
 use std::iter;
 use std::sync::Arc;

@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
     Ok(())
 }

+#[derive(Debug)]
+struct TakeUDF {
+    signature: Signature,
+}
+
+impl TakeUDF {
+    fn new() -> Self {
+        Self {
+            signature: Signature::any(3, Volatility::Immutable),
+        }
+    }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input values
+impl ScalarUDFImpl for TakeUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "take"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+        not_impl_err!("Not called because the return_type_from_exprs is implemented")
+    }
+
+    /// This function returns the type of the first or second argument based on
+    /// the third argument:
+    ///
+    /// 1. If the third argument is '0', return the type of the first argument
+    /// 2. If the third argument is '1', return the type of the second argument
+    fn return_type_from_exprs(
+        &self,
+        arg_exprs: &[Expr],
+        schema: &dyn ExprSchema,
+    ) -> Result<DataType> {
+        if arg_exprs.len() != 3 {
+            return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
+        }
+
+        let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
+            arg_exprs.get(2)
+        {
+            if *idx == 0 || *idx == 1 {
+                *idx as usize
+            } else {
+                return plan_err!("The third argument must be 0 or 1, got: {idx}");
+            }
+        } else {
+            return plan_err!(
+                "The third argument must be a literal of type int64, but got {:?}",
+                arg_exprs.get(2)
+            );
+        };
+
+        arg_exprs.get(take_idx).unwrap().get_type(schema)
+    }
+
+    // The actual implementation
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        let take_idx = match &args[2] {
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
+            _ => unreachable!(),
+        };
+        match &args[take_idx] {
+            ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())),
+            ColumnarValue::Scalar(_) => unimplemented!(),
+        }
+    }
+}
+
+#[tokio::test]
+async fn verify_udf_return_type() -> Result<()> {
+    // Create a new ScalarUDF from the implementation
+    let take = ScalarUDF::from(TakeUDF::new());
+
+    // SELECT
+    //   take(smallint_col, double_col, 0) as take0,
+    //   take(smallint_col, double_col, 1) as take1
+    // FROM alltypes_plain;
+    let exprs = vec![
+        take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
+            .alias("take0"),
+        take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
+            .alias("take1"),
+    ];
+
+    let ctx = SessionContext::new();
+    register_alltypes_parquet(&ctx).await?;
+
+    let df = ctx.table("alltypes_plain").await?.select(exprs)?;
+
+    let schema = df.schema();
+
+    // The output schema should be
+    // * type of column smallint_col (int32)
+    // * type of column double_col (float64)
+    assert_eq!(schema.field(0).data_type(), &DataType::Int32);
+    assert_eq!(schema.field(1).data_type(), &DataType::Float64);
+
+    let expected = [
+        "+-------+-------+",
+        "| take0 | take1 |",
+        "+-------+-------+",
+        "| 0     | 0.0   |",
+        "| 0     | 0.0   |",
+        "| 0     | 0.0   |",
+        "| 0     | 0.0   |",
+        "| 1     | 10.1  |",
+        "| 1     | 10.1  |",
+        "| 1     | 10.1  |",
+        "| 1     | 10.1  |",
+        "+-------+-------+",
+    ];
+    assert_batches_sorted_eq!(&expected, &df.collect().await?);
+
+    Ok(())
+}
+
 fn create_udf_context() -> SessionContext {
     let ctx = SessionContext::new();
     // register a custom UDF

@@ -531,6 +656,17 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> {
     Ok(())
 }

+async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> {
+    let testdata = datafusion::test_util::parquet_test_data();
+    ctx.register_parquet(
+        "alltypes_plain",
+        &format!("{testdata}/alltypes_plain.parquet"),
+        ParquetReadOptions::default(),
+    )
+    .await?;
+    Ok(())
+}
+
 /// Execute SQL and return results as a RecordBatch
 async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
     ctx.sql(sql).await?.collect().await
datafusion/expr/src/expr_schema.rs

Lines changed: 22 additions & 18 deletions
@@ -28,35 +28,37 @@ use crate::{utils, LogicalPlan, Projection, Subquery};
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::{
-    internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema,
-    DataFusionError, ExprSchema, Result,
+    internal_err, plan_datafusion_err, plan_err, Column, DFField, DataFusionError,
+    ExprSchema, Result,
 };
 use std::collections::HashMap;
 use std::sync::Arc;

 /// trait to allow expr to typable with respect to a schema
 pub trait ExprSchemable {
     /// given a schema, return the type of the expr
-    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
+    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;

     /// given a schema, return the nullability of the expr
-    fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
+    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;

     /// given a schema, return the expr's optional metadata
-    fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, String>>;
+    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>;

     /// convert to a field with respect to a schema
-    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
+    fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField>;

     /// cast to a type with respect to a schema
-    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
+    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
 }

 impl ExprSchemable for Expr {
     /// Returns the [arrow::datatypes::DataType] of the expression
     /// based on [ExprSchema]
     ///
-    /// Note: [DFSchema] implements [ExprSchema].
+    /// Note: [`DFSchema`] implements [ExprSchema].
+    ///
+    /// [`DFSchema`]: datafusion_common::DFSchema
     ///
     /// # Examples
     ///

@@ -90,7 +92,7 @@ impl ExprSchemable for Expr {
     /// expression refers to a column that does not exist in the
     /// schema, or when the expression is incorrectly typed
     /// (e.g. `[utf8] + [bool]`).
-    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
+    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
         match self {
             Expr::Alias(Alias { expr, name, .. }) => match &**expr {
                 Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {

@@ -136,7 +138,7 @@ impl ExprSchemable for Expr {
                         fun.return_type(&arg_data_types)
                     }
                     ScalarFunctionDefinition::UDF(fun) => {
-                        Ok(fun.return_type(&arg_data_types)?)
+                        Ok(fun.return_type_from_exprs(args, schema)?)
                     }
                     ScalarFunctionDefinition::Name(_) => {
                         internal_err!("Function `Expr` with name should be resolved.")

@@ -213,14 +215,16 @@ impl ExprSchemable for Expr {

     /// Returns the nullability of the expression based on [ExprSchema].
     ///
-    /// Note: [DFSchema] implements [ExprSchema].
+    /// Note: [`DFSchema`] implements [ExprSchema].
+    ///
+    /// [`DFSchema`]: datafusion_common::DFSchema
     ///
     /// # Errors
     ///
     /// This function errors when it is not possible to compute its
     /// nullability. This happens when the expression refers to a
     /// column that does not exist in the schema.
-    fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
+    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
         match self {
             Expr::Alias(Alias { expr, .. })
             | Expr::Not(expr)

@@ -327,7 +331,7 @@ impl ExprSchemable for Expr {
         }
     }

-    fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, String>> {
+    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> {
         match self {
             Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
             Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),

@@ -339,7 +343,7 @@ impl ExprSchemable for Expr {
     ///
     /// So for example, a projected expression `col(c1) + col(c2)` is
     /// placed in an output field **named** col("c1 + c2")
-    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
+    fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField> {
         match self {
             Expr::Column(c) => Ok(DFField::new(
                 c.relation.clone(),

@@ -370,7 +374,7 @@ impl ExprSchemable for Expr {
     ///
     /// This function errors when it is impossible to cast the
     /// expression to the target [arrow::datatypes::DataType].
-    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> {
+    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
         let this_type = self.get_type(schema)?;
         if this_type == *cast_to_type {
             return Ok(self);

@@ -394,10 +398,10 @@ impl ExprSchemable for Expr {
 }

 /// return the schema [`Field`] for the type referenced by `get_indexed_field`
-fn field_for_index<S: ExprSchema>(
+fn field_for_index(
     expr: &Expr,
     field: &GetFieldAccess,
-    schema: &S,
+    schema: &dyn ExprSchema,
 ) -> Result<Field> {
     let expr_dt = expr.get_type(schema)?;
     match field {

@@ -457,7 +461,7 @@ mod tests {
     use super::*;
     use crate::{col, lit};
     use arrow::datatypes::{DataType, Fields};
-    use datafusion_common::{Column, ScalarValue, TableReference};
+    use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};

     macro_rules! test_is_expr_nullable {
         ($EXPR_TYPE:ident) => {{
