rewrite struct, wrapping it with a cast
gstvg committed Mar 25, 2024
1 parent 47f4b5a commit 98ff239
Showing 5 changed files with 72 additions and 11 deletions.
2 changes: 1 addition & 1 deletion datafusion/functions/src/core/mod.rs
@@ -23,7 +23,7 @@ mod getfield;
mod nullif;
mod nvl;
mod nvl2;
-mod r#struct;
+pub(crate) mod r#struct;

// create UDFs
make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast);
59 changes: 56 additions & 3 deletions datafusion/functions/src/core/struct.rs
@@ -17,8 +17,12 @@

use arrow::array::{ArrayRef, StructArray};
use arrow::datatypes::{DataType, Field, Fields};
-use datafusion_common::{exec_err, Result};
-use datafusion_expr::ColumnarValue;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{exec_err, DFSchema, Result};
+use datafusion_expr::expr::{Alias, ScalarFunction};
+use datafusion_expr::expr_rewriter::FunctionRewrite;
+use datafusion_expr::{cast, Cast, ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
@@ -60,6 +64,9 @@ fn struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
.collect::<Result<Vec<ArrayRef>>>()?;
Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?))
}

+const STRUCT_KEYWORD: &'static str = "struct";

#[derive(Debug)]
pub(super) struct StructFunc {
signature: Signature,
@@ -78,7 +85,7 @@ impl ScalarUDFImpl for StructFunc {
self
}
fn name(&self) -> &str {
"struct"
STRUCT_KEYWORD
}

fn signature(&self) -> &Signature {
@@ -99,6 +106,52 @@ }
}
}

+pub(crate) struct StructRewriter {}
+
+impl FunctionRewrite for StructRewriter {
+    fn name(&self) -> &str {
+        "StructRewriter"
+    }
+
+    fn rewrite(
+        &self,
+        expr: Expr,
+        schema: &DFSchema,
+        _: &ConfigOptions,
+    ) -> Result<Transformed<Expr>> {
+        match &expr {
+            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+                if func_def.name() == STRUCT_KEYWORD {
+
+                    let fields = args.iter()
+                        .enumerate()
+                        .map(|(i, arg)| {
+                            let name = match arg {
+                                Expr::Alias(Alias { expr, relation: None, name }) => {
+                                    Ok(name.clone())
+                                },
+                                Expr::Alias(alias) if alias.relation.is_some() => {
+                                    exec_err!("struct field name must be unqualified: {arg}")
+                                }
+                                _ => {
+                                    Ok(format!("c{i}"))
+                                }
+                            }?;
+
+                            Ok(Field::new(name, arg.get_type(schema)?, true))
+                        })
+                        .collect::<Result<_>>()?;
+
+                    Ok(Transformed::yes(cast(expr, DataType::Struct(fields))))
+                } else {
+                    Ok(Transformed::no(expr))
+                }
+            }
+            _ => Ok(Transformed::no(expr)),
+        }
+    }
+}

#[cfg(test)]
mod tests {
use super::*;
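For readers skimming the rewriter above: it derives one Field per argument (the alias name when one is present, otherwise a default c<i>) and then wraps the original struct(...) call in a cast to the resulting Struct type. Below is a minimal, hedged sketch of that final wrapping step in isolation, using only the cast, DataType::Struct, Field, and Fields items already imported in this file; the helper name wrap_in_struct_cast and its (name, type) input are hypothetical, not part of this commit.

use arrow::datatypes::{DataType, Field, Fields};
use datafusion_expr::{cast, Expr};

// Hypothetical helper: given the already-planned `struct(...)` call and the
// (name, type) pairs derived from its arguments, build the same shape the
// rewriter returns, i.e. CAST(struct(...) AS Struct<...>).
fn wrap_in_struct_cast(struct_call: Expr, derived: Vec<(String, DataType)>) -> Expr {
    let fields: Fields = derived
        .into_iter()
        // Fields are nullable, matching the rewriter's Field::new(.., .., true).
        .map(|(name, data_type)| Field::new(name, data_type, true))
        .collect();
    cast(struct_call, DataType::Struct(fields))
}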
4 changes: 4 additions & 0 deletions datafusion/functions/src/lib.rs
@@ -77,6 +77,8 @@
//! feature is enabled.
//!
//! [`ScalarUDF`]: datafusion_expr::ScalarUDF
+use std::sync::Arc;

use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
use log::debug;
@@ -153,6 +155,8 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
.chain(crypto::functions())
.chain(string::functions());

+registry.register_function_rewrite(Arc::new(core::r#struct::StructRewriter{}))?;

all_functions.try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
if let Some(existing_udf) = existing_udf {
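As a hedged sketch, the same registration could also be done on its own against any FunctionRegistry, mirroring the register_function_rewrite call added above. The free function register_struct_rewrite is hypothetical and, since StructRewriter is pub(crate), the sketch assumes it lives inside the datafusion-functions crate.

use std::sync::Arc;

use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;

// Hypothetical standalone helper; `register_all` above does the equivalent
// alongside registering the scalar UDFs themselves.
pub(crate) fn register_struct_rewrite(registry: &mut dyn FunctionRegistry) -> Result<()> {
    registry.register_function_rewrite(Arc::new(crate::core::r#struct::StructRewriter {}))?;
    Ok(())
}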
4 changes: 4 additions & 0 deletions datafusion/sql/src/expr/mod.rs
Expand Up @@ -572,6 +572,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Struct { values, fields } => {
self.parse_struct(values, fields, schema, planner_context)
}
+SQLExpr::Named { expr, name } => {
+    Ok(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?
+        .alias(name.value))
+}
SQLExpr::Position { expr, r#in } => {
self.sql_position_to_expr(*expr, *r#in, schema, planner_context)
}
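If the parser surfaces an argument such as b AS b_name as SQLExpr::Named, the new arm above lowers it to a plain alias on the logical expression, which is exactly the Expr::Alias shape StructRewriter matches on. A hedged sketch of the equivalent expression built directly with the expr API (col and alias are existing datafusion_expr helpers; the wrapper function is illustrative only):

use datafusion_expr::{col, Expr};

// Illustrative only: the logical expression the Named arm is expected to
// produce for `b AS b_name` inside struct(a, b AS b_name, c).
fn named_struct_argument() -> Expr {
    col("b").alias("b_name")
}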
14 changes: 7 additions & 7 deletions datafusion/sqllogictest/test_files/struct.slt
@@ -52,21 +52,21 @@ select struct(1, 3.14, 'e');

# struct scalar function with columns #1
query ?
-select struct(a, b, c) from values;
+select struct(a, b as b_name, c) from values;
----
-{c0: 1, c1: 1.1, c2: a}
-{c0: 2, c1: 2.2, c2: b}
-{c0: 3, c1: 3.3, c2: c}
+{c0: 1, b_name: 1.1, c2: a}
+{c0: 2, b_name: 2.2, c2: b}
+{c0: 3, b_name: 3.3, c2: c}

# explain struct scalar function with columns #1
query TT
-explain select struct(a, b, c) from values;
+explain select struct(a, b as b_name, c) from values;
----
logical_plan
-Projection: struct(values.a, values.b, values.c)
+Projection: CAST(struct(values.a, values.b AS b_name, values.c) AS Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b_name", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]))
--TableScan: values projection=[a, b, c]
physical_plan
-ProjectionExec: expr=[struct(a@0, b@1, c@2) as struct(values.a,values.b,values.c)]
+ProjectionExec: expr=[CAST(struct(a@0, b@1, c@2) AS Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b_name", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])) as struct(values.a,b_name,values.c)]
--MemoryExec: partitions=1, partition_sizes=[1]

statement ok
