Skip to content

Commit

Permalink
Support Substrait's VirtualTables (#10531)
Browse files Browse the repository at this point in the history
* Add support for Substrait VirtualTables

Adds support for Substrait's VirtualTables, i.e. tables whose data is baked into the Substrait plan instead of being read from a source.

Adds conversion in both directions (Substrait -> DataFusion and DataFusion -> Substrait)
and a roundtrip test.

* fix clippy

* Add support for empty relations

* Fix consuming Structs inside Lists and Structs

Also adds roundtrip schema assertions for cases where possible

* Rename from_substrait_struct -> from_substrait_struct_type for clarity

* Add DataType::LargeList to to_substrait_named_struct

* cargo fmt --all

* Add validation that names list matches schema exactly

* Add a LargeList into VALUES test
  • Loading branch information
Blizzara authored May 28, 2024
1 parent 5a9712e commit 2762754
Show file tree
Hide file tree
Showing 3 changed files with 361 additions and 62 deletions.
191 changes: 157 additions & 34 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
// under the License.

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion::arrow::datatypes::{
DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};

use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, Expr, LogicalPlan,
Operator, ScalarUDF,
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr,
LogicalPlan, Operator, ScalarUDF, Values,
};
use datafusion::logical_expr::{
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Expand Down Expand Up @@ -58,7 +60,7 @@ use substrait::proto::{
rel::RelType,
set_rel,
sort_field::{SortDirection, SortKind::*},
AggregateFunction, Expression, Plan, Rel, Type,
AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
};
use substrait::proto::{FunctionArgument, SortField};

Expand Down Expand Up @@ -509,7 +511,51 @@ pub async fn from_substrait_rel(
_ => Ok(t),
}
}
_ => not_impl_err!("Only NamedTable reads are supported"),
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Virtual Table")
})?;

let schema = from_substrait_named_struct(base_schema)?;

if vt.values.is_empty() {
return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema,
}));
}

let values = vt
.values
.iter()
.map(|row| {
let mut name_idx = 0;
let lits = row
.fields
.iter()
.map(|lit| {
name_idx += 1; // top-level names are provided through schema
Ok(Expr::Literal(from_substrait_literal(
lit,
&base_schema.names,
&mut name_idx,
)?))
})
.collect::<Result<_>>()?;
if name_idx != base_schema.names.len() {
return substrait_err!(
"Names list must match exactly to nested schema, but found {} uses for {} names",
name_idx,
base_schema.names.len()
);
}
Ok(lits)
})
.collect::<Result<_>>()?;

Ok(LogicalPlan::Values(Values { schema, values }))
}
_ => not_impl_err!("Only NamedTable and VirtualTable reads are supported"),
},
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
Ok(set_op) => match set_op {
Expand Down Expand Up @@ -948,7 +994,7 @@ pub async fn from_substrait_rex(
}
}
Some(RexType::Literal(lit)) => {
let scalar_value = from_substrait_literal(lit)?;
let scalar_value = from_substrait_literal_without_names(lit)?;
Ok(Arc::new(Expr::Literal(scalar_value)))
}
Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() {
Expand All @@ -964,9 +1010,9 @@ pub async fn from_substrait_rex(
.as_ref()
.clone(),
),
from_substrait_type(output_type)?,
from_substrait_type_without_names(output_type)?,
)))),
None => substrait_err!("Cast experssion without output type is not allowed"),
None => substrait_err!("Cast expression without output type is not allowed"),
},
Some(RexType::WindowFunction(window)) => {
let fun = match extensions.get(&window.function_reference) {
Expand Down Expand Up @@ -1062,7 +1108,15 @@ pub async fn from_substrait_rex(
}
}

pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
/// Converts a Substrait type into an Arrow `DataType` when no field names
/// are available.
///
/// Delegates to `from_substrait_type` with an empty names list and a fresh
/// name cursor starting at zero.
pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result<DataType> {
    // No names are supplied, so the cursor is a throwaway local.
    let mut name_idx = 0;
    from_substrait_type(dt, &[], &mut name_idx)
}

fn from_substrait_type(
dt: &Type,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<DataType> {
match &dt.kind {
Some(s_kind) => match s_kind {
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
Expand Down Expand Up @@ -1142,7 +1196,7 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
substrait_datafusion_err!("List type must have inner type")
})?;
let field = Arc::new(Field::new_list_field(
from_substrait_type(inner_type)?,
from_substrait_type(inner_type, dfs_names, name_idx)?,
is_substrait_type_nullable(inner_type)?,
));
match list.type_variation_reference {
Expand Down Expand Up @@ -1182,24 +1236,69 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
),
}
},
r#type::Kind::Struct(s) => {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
let field = Field::new(
&format!("c{i}"),
from_substrait_type(f)?,
is_substrait_type_nullable(f)?,
);
fields.push(field);
}
Ok(DataType::Struct(fields.into()))
}
r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type(
s, dfs_names, name_idx,
)?)),
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
},
_ => not_impl_err!("`None` Substrait kind is not supported"),
}
}

fn from_substrait_struct_type(
s: &r#type::Struct,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<Fields> {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
let field = Field::new(
next_struct_field_name(i, dfs_names, name_idx)?,
from_substrait_type(f, dfs_names, name_idx)?,
is_substrait_type_nullable(f)?,
);
fields.push(field);
}
Ok(fields.into())
}

/// Produces the name for the next struct field.
///
/// * When `dfs_names` is empty, a dummy name `c{i}` is generated from the
///   field's position `i` and `name_idx` is left untouched.
/// * Otherwise the name at `*name_idx` is returned and `name_idx` is advanced
///   by one. An error is returned (without advancing) if the list has no
///   entry at that position.
fn next_struct_field_name(
    i: usize,
    dfs_names: &[String],
    name_idx: &mut usize,
) -> Result<String> {
    if dfs_names.is_empty() {
        // No names supplied: fall back to dummy names.
        // c0, c1, ... align with e.g. SqlToRel::create_named_struct
        return Ok(format!("c{i}"));
    }

    match dfs_names.get(*name_idx) {
        Some(found) => {
            let owned = found.clone();
            *name_idx += 1;
            Ok(owned)
        }
        None => Err(substrait_datafusion_err!(
            "Named schema must contain names for all fields"
        )),
    }
}

/// Converts a Substrait `NamedStruct` into a DataFusion schema.
///
/// The struct's member types are converted with `from_substrait_struct_type`,
/// which takes field names from `base_schema.names` one at a time as it walks
/// the (possibly nested) struct.
///
/// # Errors
///
/// * The named struct carries no `struct` definition.
/// * Converting any member type fails; that underlying error is returned
///   as-is.
/// * The conversion succeeded but did not consume exactly
///   `base_schema.names.len()` names.
/// * The resulting Arrow schema cannot be turned into a `DFSchema`.
fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result<DFSchemaRef> {
    let struct_type = base_schema.r#struct.as_ref().ok_or_else(|| {
        substrait_datafusion_err!("Named struct must contain a struct")
    })?;

    let mut name_idx = 0;
    // Propagate conversion failures immediately. If the walk aborts early,
    // `name_idx` has not reached the end of the names list, and checking the
    // count first would replace the real cause (e.g. an unsupported type)
    // with a misleading "names list must match" error.
    let fields =
        from_substrait_struct_type(struct_type, &base_schema.names, &mut name_idx)?;

    // Every provided name must have been used by the conversion above.
    if name_idx != base_schema.names.len() {
        return substrait_err!(
            "Names list must match exactly to nested schema, but found {} uses for {} names",
            name_idx,
            base_schema.names.len()
        );
    }

    Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields))?))
}

fn is_substrait_type_nullable(dtype: &Type) -> Result<bool> {
fn is_nullable(nullability: i32) -> bool {
nullability != substrait::proto::r#type::Nullability::Required as i32
Expand Down Expand Up @@ -1277,7 +1376,15 @@ fn from_substrait_bound(
}
}

pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
/// Converts a Substrait literal into a `ScalarValue` when no field names
/// are available.
///
/// Delegates to `from_substrait_literal` with an empty names list and a fresh
/// name cursor starting at zero.
pub(crate) fn from_substrait_literal_without_names(lit: &Literal) -> Result<ScalarValue> {
    let no_names: Vec<String> = Vec::new();
    let mut name_idx = 0;
    from_substrait_literal(lit, &no_names, &mut name_idx)
}

fn from_substrait_literal(
lit: &Literal,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<ScalarValue> {
let scalar_value = match &lit.literal_type {
Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
Some(LiteralType::I8(n)) => match lit.type_variation_reference {
Expand Down Expand Up @@ -1359,7 +1466,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
let elements = l
.values
.iter()
.map(from_substrait_literal)
.map(|el| from_substrait_literal(el, dfs_names, name_idx))
.collect::<Result<Vec<_>>>()?;
if elements.is_empty() {
return substrait_err!(
Expand All @@ -1381,7 +1488,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
}
}
Some(LiteralType::EmptyList(l)) => {
let element_type = from_substrait_type(l.r#type.clone().unwrap().as_ref())?;
let element_type = from_substrait_type(
l.r#type.clone().unwrap().as_ref(),
dfs_names,
name_idx,
)?;
match lit.type_variation_reference {
DEFAULT_CONTAINER_TYPE_REF => {
ScalarValue::List(ScalarValue::new_list(&[], &element_type))
Expand All @@ -1397,16 +1508,16 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
Some(LiteralType::Struct(s)) => {
let mut builder = ScalarStructBuilder::new();
for (i, field) in s.fields.iter().enumerate() {
let sv = from_substrait_literal(field)?;
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
builder = builder.with_scalar(
Field::new(&format!("c{i}"), sv.data_type(), field.nullable),
sv,
);
let name = next_struct_field_name(i, dfs_names, name_idx)?;
let sv = from_substrait_literal(field, dfs_names, name_idx)?;
builder = builder
.with_scalar(Field::new(name, sv.data_type(), field.nullable), sv);
}
builder.build()?
}
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
Some(LiteralType::Null(ntype)) => {
from_substrait_null(ntype, dfs_names, name_idx)?
}
Some(LiteralType::UserDefined(user_defined)) => {
match user_defined.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
Expand Down Expand Up @@ -1461,7 +1572,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
Ok(scalar_value)
}

fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
fn from_substrait_null(
null_type: &Type,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)),
Expand Down Expand Up @@ -1539,7 +1654,11 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
)),
r#type::Kind::List(l) => {
let field = Field::new_list_field(
from_substrait_type(l.r#type.clone().unwrap().as_ref())?,
from_substrait_type(
l.r#type.clone().unwrap().as_ref(),
dfs_names,
name_idx,
)?,
true,
);
match l.type_variation_reference {
Expand All @@ -1554,6 +1673,10 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
),
}
}
r#type::Kind::Struct(s) => {
let fields = from_substrait_struct_type(s, dfs_names, name_idx)?;
Ok(ScalarStructBuilder::new_null(fields))
}
_ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"),
}
} else {
Expand Down
Loading

0 comments on commit 2762754

Please sign in to comment.