Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
// under the License.

use crate::logical_plan::consumer::{from_substrait_func_args, SubstraitConsumer};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::Result;
use datafusion::common::{
not_impl_err, plan_err, substrait_err, DFSchema, DataFusionError, ScalarValue,
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{expr, Between, BinaryExpr, Expr, Like, Operator};
use std::sync::Arc;
use std::vec::Drain;
use substrait::proto::expression::ScalarFunction;

Expand All @@ -42,6 +44,33 @@ pub async fn from_scalar_function(
};

let fn_name = substrait_fun_name(fn_signature);
if fn_name == "outer_reference" {
let arg = f.arguments.first().ok_or_else(|| {
DataFusionError::Substrait(
"outer_reference function requires at least one argument".to_string(),
)
})?;

let col_name = match &arg.arg_type {
Some(substrait::proto::function_argument::ArgType::Value(e)) => {
match &e.rex_type {
Some(substrait::proto::expression::RexType::Literal(
substrait::proto::expression::Literal {
literal_type: Some(substrait::proto::expression::literal::LiteralType::String(s)),
..
},
)) => s.clone(),
_ => return substrait_err!("outer_reference argument must be a string literal"),
}
}
_ => return substrait_err!("outer_reference argument must be a value"),
};

return Ok(Expr::OuterReferenceColumn(
Arc::new(Field::new("placeholder", DataType::Null, true)),
datafusion::common::Column::from_qualified_name(col_name),
));
}
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;

let udf_func = consumer.get_function_registry().udf(fn_name).or_else(|e| {
Expand Down
33 changes: 29 additions & 4 deletions datafusion/substrait/src/logical_plan/producer/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub use if_then::*;
pub use literal::*;
pub use scalar_function::*;
pub use singular_or_list::*;
pub use subquery::*;
pub use subquery::{from_exists, from_in_subquery};
pub use window_function::*;

use crate::logical_plan::producer::utils::flatten_names;
Expand Down Expand Up @@ -139,7 +139,7 @@ pub fn to_substrait_rex(
}
Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema),
Expr::InList(expr) => producer.handle_in_list(expr, schema),
Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::Exists(expr) => producer.handle_exists(expr),
Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema),
Expr::ScalarSubquery(expr) => {
not_impl_err!("Cannot convert {expr:?} to Substrait")
Expand All @@ -148,8 +148,33 @@ pub fn to_substrait_rex(
Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::OuterReferenceColumn(_, _) => {
not_impl_err!("Cannot convert {expr:?} to Substrait")
Expr::OuterReferenceColumn(_, col) => {
let function_anchor =
producer.register_function("outer_reference".to_string());
Ok(Expression {
rex_type: Some(substrait::proto::expression::RexType::ScalarFunction(
substrait::proto::expression::ScalarFunction {
function_reference: function_anchor,
arguments: vec![substrait::proto::FunctionArgument {
arg_type: Some(substrait::proto::function_argument::ArgType::Value(
Expression {
rex_type: Some(substrait::proto::expression::RexType::Literal(
substrait::proto::expression::Literal {
literal_type: Some(
substrait::proto::expression::literal::LiteralType::String(
col.to_string()
)
),
..Default::default()
}
))
}
))
}],
..Default::default()
}
))
})
}
Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
}
Expand Down
45 changes: 45 additions & 0 deletions datafusion/substrait/src/logical_plan/producer/expr/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

use crate::logical_plan::producer::SubstraitProducer;
use datafusion::common::DFSchemaRef;
use datafusion::logical_expr::expr::Exists;
use datafusion::logical_expr::expr::InSubquery;
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::subquery::SetPredicate;
use substrait::proto::expression::{RexType, ScalarFunction};
use substrait::proto::function_argument::ArgType;
use substrait::proto::{Expression, FunctionArgument};
Expand Down Expand Up @@ -70,3 +72,46 @@ pub fn from_in_subquery(
Ok(substrait_subquery)
}
}

pub fn from_exists(
producer: &mut impl SubstraitProducer,
exists: &Exists,
) -> datafusion::common::Result<Expression> {
let Exists { subquery, negated } = exists;

let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?;

let substrait_subquery = Expression {
rex_type: Some(RexType::Subquery(Box::new(
substrait::proto::expression::Subquery {
subquery_type: Some(
substrait::proto::expression::subquery::SubqueryType::SetPredicate(
Box::new(SetPredicate {
predicate_op: 1,
tuples: Some(subquery_plan),
}),
),
),
},
))),
};

if *negated {
let function_anchor = producer.register_function("not".to_string());

#[allow(deprecated)]
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(substrait_subquery)),
}],
output_type: None,
args: vec![],
options: vec![],
})),
})
} else {
Ok(substrait_subquery)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
use crate::extensions::Extensions;
use crate::logical_plan::producer::{
from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr,
from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter,
from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal,
from_projection, from_repartition, from_scalar_function, from_sort,
from_case, from_cast, from_column, from_distinct, from_empty_relation, from_exists,
from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit,
from_literal, from_projection, from_repartition, from_scalar_function, from_sort,
from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union,
from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex,
};
Expand Down Expand Up @@ -359,6 +359,13 @@ pub trait SubstraitProducer: Send + Sync + Sized {
) -> datafusion::common::Result<Expression> {
from_in_subquery(self, in_subquery, schema)
}

fn handle_exists(
&mut self,
exists: &expr::Exists,
) -> datafusion::common::Result<Expression> {
from_exists(self, exists)
}
}

pub struct DefaultSubstraitProducer<'a> {
Expand Down