Skip to content

Commit 59b3958

Browse files
alambyyy1000
and authored
Proposed Return Type from Expr suggestions (#1)
* Improve return_type_from_args * Rework example * Update datafusion/core/tests/user_defined/user_defined_scalar_functions.rs --------- Co-authored-by: Junhao Liu <junhaoliu2023@gmail.com>
1 parent 5772d9f commit 59b3958

File tree

6 files changed

+193
-198
lines changed

6 files changed

+193
-198
lines changed

datafusion-examples/examples/return_types_udf.rs

Lines changed: 0 additions & 150 deletions
This file was deleted.

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 139 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@ use arrow_schema::{DataType, Field, Schema};
2222
use datafusion::prelude::*;
2323
use datafusion::{execution::registry::FunctionRegistry, test_util};
2424
use datafusion_common::cast::as_float64_array;
25-
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
25+
use datafusion_common::{
26+
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
27+
plan_err, DFSchema, DataFusionError, Result, ScalarValue,
28+
};
2629
use datafusion_expr::{
27-
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF,
28-
ScalarUDFImpl, Signature, Volatility,
30+
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
31+
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
2932
};
3033
use rand::{thread_rng, Rng};
34+
use std::any::Any;
3135
use std::iter;
3236
use std::sync::Arc;
3337

@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
494498
Ok(())
495499
}
496500

501+
#[derive(Debug)]
502+
struct TakeUDF {
503+
signature: Signature,
504+
}
505+
506+
impl TakeUDF {
507+
fn new() -> Self {
508+
Self {
509+
signature: Signature::any(3, Volatility::Immutable),
510+
}
511+
}
512+
}
513+
514+
/// Implement a ScalarUDFImpl whose return type is a function of the input values
515+
impl ScalarUDFImpl for TakeUDF {
516+
fn as_any(&self) -> &dyn Any {
517+
self
518+
}
519+
fn name(&self) -> &str {
520+
"take"
521+
}
522+
fn signature(&self) -> &Signature {
523+
&self.signature
524+
}
525+
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
526+
not_impl_err!("Not called because the return_type_from_exprs is implemented")
527+
}
528+
529+
/// This function returns the type of the first or second argument based on
530+
/// the third argument:
531+
///
532+
/// 1. If the third argument is '0', return the type of the first argument
533+
/// 2. If the third argument is '1', return the type of the second argument
534+
fn return_type_from_exprs(
535+
&self,
536+
arg_exprs: &[Expr],
537+
schema: &DFSchema,
538+
) -> Result<DataType> {
539+
if arg_exprs.len() != 3 {
540+
return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
541+
}
542+
543+
let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
544+
arg_exprs.get(2)
545+
{
546+
if *idx == 0 || *idx == 1 {
547+
*idx as usize
548+
} else {
549+
return plan_err!("The third argument must be 0 or 1, got: {idx}");
550+
}
551+
} else {
552+
return plan_err!(
553+
"The third argument must be a literal of type int64, but got {:?}",
554+
arg_exprs.get(2)
555+
);
556+
};
557+
558+
arg_exprs.get(take_idx).unwrap().get_type(schema)
559+
}
560+
561+
// The actual implementation: return the first or second argument, as selected by the third
562+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
563+
let take_idx = match &args[2] {
564+
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
565+
_ => unreachable!(),
566+
};
567+
match &args[take_idx] {
568+
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())),
569+
ColumnarValue::Scalar(_) => unimplemented!(),
570+
}
571+
}
572+
}
573+
574+
#[tokio::test]
575+
async fn verify_udf_return_type() -> Result<()> {
576+
// Create a new ScalarUDF from the implementation
577+
let take = ScalarUDF::from(TakeUDF::new());
578+
579+
// SELECT
580+
// take(smallint_col, double_col, 0) as take0,
581+
// take(smallint_col, double_col, 1) as take1
582+
// FROM alltypes_plain;
583+
let exprs = vec![
584+
take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
585+
.alias("take0"),
586+
take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
587+
.alias("take1"),
588+
];
589+
590+
let ctx = SessionContext::new();
591+
register_alltypes_parquet(&ctx).await?;
592+
593+
let df = ctx.table("alltypes_plain").await?.select(exprs)?;
594+
595+
let schema = df.schema();
596+
597+
// The output schema should be
598+
// * type of column smallint_col (int32)
599+
// * type of column double_col (float64)
600+
assert_eq!(schema.field(0).data_type(), &DataType::Int32);
601+
assert_eq!(schema.field(1).data_type(), &DataType::Float64);
602+
603+
let expected = [
604+
"+-------+-------+",
605+
"| take0 | take1 |",
606+
"+-------+-------+",
607+
"| 0 | 0.0 |",
608+
"| 0 | 0.0 |",
609+
"| 0 | 0.0 |",
610+
"| 0 | 0.0 |",
611+
"| 1 | 10.1 |",
612+
"| 1 | 10.1 |",
613+
"| 1 | 10.1 |",
614+
"| 1 | 10.1 |",
615+
"+-------+-------+",
616+
];
617+
assert_batches_sorted_eq!(&expected, &df.collect().await?);
618+
619+
Ok(())
620+
}
621+
497622
fn create_udf_context() -> SessionContext {
498623
let ctx = SessionContext::new();
499624
// register a custom UDF
@@ -531,6 +656,17 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> {
531656
Ok(())
532657
}
533658

659+
async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> {
660+
let testdata = datafusion::test_util::parquet_test_data();
661+
ctx.register_parquet(
662+
"alltypes_plain",
663+
&format!("{testdata}/alltypes_plain.parquet"),
664+
ParquetReadOptions::default(),
665+
)
666+
.await?;
667+
Ok(())
668+
}
669+
534670
/// Execute SQL and return results as a RecordBatch
535671
async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
536672
ctx.sql(sql).await?.collect().await

datafusion/expr/src/expr_schema.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ impl ExprSchemable<DFSchema> for Expr {
330330

331331
fn metadata(&self, schema: &DFSchema) -> Result<HashMap<String, String>> {
332332
match self {
333-
Expr::Column(c) => Ok(schema.metadata().clone()),
333+
Expr::Column(_) => Ok(schema.metadata().clone()),
334334
Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
335335
_ => Ok(HashMap::new()),
336336
}

0 commit comments

Comments
 (0)