Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Substrait's VirtualTables #10531

Merged
merged 10 commits into from
May 28, 2024
191 changes: 157 additions & 34 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
// under the License.

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion::arrow::datatypes::{
DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};

use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, Expr, LogicalPlan,
Operator, ScalarUDF,
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr,
LogicalPlan, Operator, ScalarUDF, Values,
};
use datafusion::logical_expr::{
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Expand Down Expand Up @@ -58,7 +60,7 @@ use substrait::proto::{
rel::RelType,
set_rel,
sort_field::{SortDirection, SortKind::*},
AggregateFunction, Expression, Plan, Rel, Type,
AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
};
use substrait::proto::{FunctionArgument, SortField};

Expand Down Expand Up @@ -509,7 +511,51 @@ pub async fn from_substrait_rel(
_ => Ok(t),
}
}
_ => not_impl_err!("Only NamedTable reads are supported"),
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Virtual Table")
})?;

let schema = from_substrait_named_struct(base_schema)?;

if vt.values.is_empty() {
return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema,
}));
}

let values = vt
.values
.iter()
.map(|row| {
let mut name_idx = 0;
let lits = row
.fields
.iter()
.map(|lit| {
name_idx += 1; // top-level names are provided through schema
Ok(Expr::Literal(from_substrait_literal(
lit,
&base_schema.names,
&mut name_idx,
)?))
})
.collect::<Result<_>>()?;
if name_idx != base_schema.names.len() {
return substrait_err!(
"Names list must match exactly to nested schema, but found {} uses for {} names",
name_idx,
base_schema.names.len()
);
}
Ok(lits)
})
.collect::<Result<_>>()?;

Ok(LogicalPlan::Values(Values { schema, values }))
}
_ => not_impl_err!("Only NamedTable and VirtualTable reads are supported"),
},
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
Ok(set_op) => match set_op {
Expand Down Expand Up @@ -948,7 +994,7 @@ pub async fn from_substrait_rex(
}
}
Some(RexType::Literal(lit)) => {
let scalar_value = from_substrait_literal(lit)?;
let scalar_value = from_substrait_literal_without_names(lit)?;
Ok(Arc::new(Expr::Literal(scalar_value)))
}
Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() {
Expand All @@ -964,9 +1010,9 @@ pub async fn from_substrait_rex(
.as_ref()
.clone(),
),
from_substrait_type(output_type)?,
from_substrait_type_without_names(output_type)?,
)))),
None => substrait_err!("Cast experssion without output type is not allowed"),
None => substrait_err!("Cast expression without output type is not allowed"),
},
Some(RexType::WindowFunction(window)) => {
let fun = match extensions.get(&window.function_reference) {
Expand Down Expand Up @@ -1062,7 +1108,15 @@ pub async fn from_substrait_rex(
}
}

pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
/// Converts a Substrait [`Type`] into a [`DataType`] when no field names are
/// available.
///
/// An empty name list is handed to the name-aware conversion, so fields of any
/// nested struct receive the generated names `c0`, `c1`, ...
pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result<DataType> {
    // With no names to consume, the running name index is only a placeholder.
    let mut unused_name_idx = 0;
    from_substrait_type(dt, &[], &mut unused_name_idx)
}

fn from_substrait_type(
dt: &Type,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<DataType> {
match &dt.kind {
Some(s_kind) => match s_kind {
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
Expand Down Expand Up @@ -1142,7 +1196,7 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
substrait_datafusion_err!("List type must have inner type")
})?;
let field = Arc::new(Field::new_list_field(
from_substrait_type(inner_type)?,
from_substrait_type(inner_type, dfs_names, name_idx)?,
is_substrait_type_nullable(inner_type)?,
));
match list.type_variation_reference {
Expand Down Expand Up @@ -1182,24 +1236,69 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
),
}
},
r#type::Kind::Struct(s) => {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
let field = Field::new(
&format!("c{i}"),
from_substrait_type(f)?,
is_substrait_type_nullable(f)?,
);
fields.push(field);
}
Ok(DataType::Struct(fields.into()))
}
r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type(
s, dfs_names, name_idx,
)?)),
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
},
_ => not_impl_err!("`None` Substrait kind is not supported"),
}
}

fn from_substrait_struct_type(
s: &r#type::Struct,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<Fields> {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
let field = Field::new(
next_struct_field_name(i, dfs_names, name_idx)?,
from_substrait_type(f, dfs_names, name_idx)?,
is_substrait_type_nullable(f)?,
);
fields.push(field);
}
Ok(fields.into())
}

/// Picks the name for the struct field at position `i`.
///
/// * If `dfs_names` is empty, no names were supplied and a dummy name `c{i}`
///   is generated; `name_idx` is left untouched.
/// * Otherwise the name at `*name_idx` is returned and `*name_idx` is
///   incremented. If `dfs_names` has no entry at that index an error is
///   returned and `*name_idx` is not modified.
fn next_struct_field_name(
    i: usize,
    dfs_names: &[String],
    name_idx: &mut usize,
) -> Result<String> {
    if dfs_names.is_empty() {
        // No names given: fall back to dummy names.
        // c0, c1, ... align with e.g. SqlToRel::create_named_struct
        return Ok(format!("c{i}"));
    }

    match dfs_names.get(*name_idx) {
        Some(name) => {
            *name_idx += 1;
            Ok(name.clone())
        }
        None => Err(substrait_datafusion_err!(
            "Named schema must contain names for all fields"
        )),
    }
}

/// Converts a Substrait [`NamedStruct`] into a DataFusion schema.
///
/// The struct's fields (including fields of nested types) take their names, in
/// order, from `base_schema.names`; if that list is empty, positional dummy
/// names are generated instead.
///
/// # Errors
///
/// * `base_schema` carries no struct.
/// * Converting the struct's fields fails (e.g. a name is missing or a field's
///   type cannot be converted). This error is returned as-is.
/// * The conversion succeeded but did not consume exactly
///   `base_schema.names.len()` names.
/// * Building the [`DFSchema`] from the resulting fields fails.
fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result<DFSchemaRef> {
    let struct_type = base_schema.r#struct.as_ref().ok_or_else(|| {
        substrait_datafusion_err!("Named struct must contain a struct")
    })?;

    let mut name_idx = 0;
    // Surface conversion failures before validating the name count: a
    // conversion that aborts midway leaves `name_idx` short, and reporting that
    // as a name-count mismatch would hide the actual cause.
    let fields =
        from_substrait_struct_type(struct_type, &base_schema.names, &mut name_idx)?;

    // Every supplied name must have been used by some (possibly nested) field.
    if name_idx != base_schema.names.len() {
        return substrait_err!(
            "Names list must match exactly to nested schema, but found {} uses for {} names",
            name_idx,
            base_schema.names.len()
        );
    }

    Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields))?))
}

fn is_substrait_type_nullable(dtype: &Type) -> Result<bool> {
fn is_nullable(nullability: i32) -> bool {
nullability != substrait::proto::r#type::Nullability::Required as i32
Expand Down Expand Up @@ -1277,7 +1376,15 @@ fn from_substrait_bound(
}
}

pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
/// Converts a Substrait [`Literal`] into a [`ScalarValue`] when no field names
/// are available.
///
/// An empty name list is handed to the name-aware conversion, so fields of any
/// struct literal receive the generated names `c0`, `c1`, ...
pub(crate) fn from_substrait_literal_without_names(lit: &Literal) -> Result<ScalarValue> {
    let no_names: Vec<String> = Vec::new();
    // With no names to consume, the running name index is only a placeholder.
    let mut unused_name_idx = 0;
    from_substrait_literal(lit, &no_names, &mut unused_name_idx)
}

fn from_substrait_literal(
lit: &Literal,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<ScalarValue> {
let scalar_value = match &lit.literal_type {
Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
Some(LiteralType::I8(n)) => match lit.type_variation_reference {
Expand Down Expand Up @@ -1359,7 +1466,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
let elements = l
.values
.iter()
.map(from_substrait_literal)
.map(|el| from_substrait_literal(el, dfs_names, name_idx))
.collect::<Result<Vec<_>>>()?;
if elements.is_empty() {
return substrait_err!(
Expand All @@ -1381,7 +1488,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
}
}
Some(LiteralType::EmptyList(l)) => {
let element_type = from_substrait_type(l.r#type.clone().unwrap().as_ref())?;
let element_type = from_substrait_type(
l.r#type.clone().unwrap().as_ref(),
dfs_names,
name_idx,
)?;
match lit.type_variation_reference {
DEFAULT_CONTAINER_TYPE_REF => {
ScalarValue::List(ScalarValue::new_list(&[], &element_type))
Expand All @@ -1397,16 +1508,16 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
Some(LiteralType::Struct(s)) => {
let mut builder = ScalarStructBuilder::new();
for (i, field) in s.fields.iter().enumerate() {
let sv = from_substrait_literal(field)?;
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
builder = builder.with_scalar(
Field::new(&format!("c{i}"), sv.data_type(), field.nullable),
sv,
);
let name = next_struct_field_name(i, dfs_names, name_idx)?;
let sv = from_substrait_literal(field, dfs_names, name_idx)?;
builder = builder
.with_scalar(Field::new(name, sv.data_type(), field.nullable), sv);
}
builder.build()?
}
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
Some(LiteralType::Null(ntype)) => {
from_substrait_null(ntype, dfs_names, name_idx)?
}
Some(LiteralType::UserDefined(user_defined)) => {
match user_defined.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
Expand Down Expand Up @@ -1461,7 +1572,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
Ok(scalar_value)
}

fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
fn from_substrait_null(
null_type: &Type,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)),
Expand Down Expand Up @@ -1539,7 +1654,11 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
)),
r#type::Kind::List(l) => {
let field = Field::new_list_field(
from_substrait_type(l.r#type.clone().unwrap().as_ref())?,
from_substrait_type(
l.r#type.clone().unwrap().as_ref(),
dfs_names,
name_idx,
)?,
true,
);
match l.type_variation_reference {
Expand All @@ -1554,6 +1673,10 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
),
}
}
r#type::Kind::Struct(s) => {
let fields = from_substrait_struct_type(s, dfs_names, name_idx)?;
Ok(ScalarStructBuilder::new_null(fields))
}
_ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"),
}
} else {
Expand Down
Loading