Skip to content

Commit

Permalink
feat: add nested struct supports (#1519)
Browse files Browse the repository at this point in the history
# Description
This PR adds nested struct supports.

# Related Issue(s)
- closes #1518 

# Documentation
  • Loading branch information
haruband authored Jul 15, 2023
1 parent ef8dd21 commit 4a4aaa9
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 23 deletions.
101 changes: 80 additions & 21 deletions rust/src/operations/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use arrow_array::RecordBatch;
use arrow_array::{Array, ArrayRef, RecordBatch, StructArray};
use arrow_cast::{can_cast_types, cast_with_options, CastOptions};
use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};
use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef};
use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan};
use futures::future::BoxFuture;
Expand Down Expand Up @@ -339,7 +339,7 @@ impl std::future::IntoFuture for WriteBuilder {
.or_else(|_| this.snapshot.arrow_schema())
.unwrap_or(schema.clone());

if !can_cast_batch(schema.as_ref(), table_schema.as_ref()) {
if !can_cast_batch(schema.fields(), table_schema.fields()) {
return Err(DeltaTableError::Generic(
"Updating table schema not yet implemented".to_string(),
));
Expand Down Expand Up @@ -473,19 +473,55 @@ impl std::future::IntoFuture for WriteBuilder {
}
}

fn can_cast_batch(from_schema: &ArrowSchema, to_schema: &ArrowSchema) -> bool {
if from_schema.fields.len() != to_schema.fields.len() {
fn can_cast_batch(from_fields: &Fields, to_fields: &Fields) -> bool {
if from_fields.len() != to_fields.len() {
return false;
}
from_schema.all_fields().iter().all(|f| {
if let Ok(target_field) = to_schema.field_with_name(f.name()) {
can_cast_types(f.data_type(), target_field.data_type())

from_fields.iter().all(|f| {
if let Some((_, target_field)) = to_fields.find(f.name()) {
if let (DataType::Struct(fields0), DataType::Struct(fields1)) =
(f.data_type(), target_field.data_type())
{
can_cast_batch(fields0, fields1)
} else {
can_cast_types(f.data_type(), target_field.data_type())
}
} else {
false
}
})
}

fn cast_record_batch_columns(
batch: &RecordBatch,
fields: &Fields,
cast_options: &CastOptions,
) -> Result<Vec<Arc<(dyn Array)>>, arrow_schema::ArrowError> {
fields
.iter()
.map(|f| {
let col = batch.column_by_name(f.name()).unwrap();
if let (DataType::Struct(_), DataType::Struct(child_fields)) =
(col.data_type(), f.data_type())
{
let child_batch = RecordBatch::from(StructArray::from(col.into_data()));
let child_columns =
cast_record_batch_columns(&child_batch, child_fields, cast_options)?;
Ok(Arc::new(StructArray::new(
child_fields.clone(),
child_columns.clone(),
None,
)) as ArrayRef)
} else if !col.data_type().equals_datatype(f.data_type()) {
cast_with_options(col, f.data_type(), cast_options)
} else {
Ok(col.clone())
}
})
.collect::<Result<Vec<_>, _>>()
}

fn cast_record_batch(
batch: &RecordBatch,
target_schema: ArrowSchemaRef,
Expand All @@ -496,18 +532,7 @@ fn cast_record_batch(
..Default::default()
};

let columns = target_schema
.all_fields()
.iter()
.map(|f| {
let col = batch.column_by_name(f.name()).unwrap();
if !col.data_type().equals_datatype(f.data_type()) {
cast_with_options(col, f.data_type(), &cast_options)
} else {
Ok(col.clone())
}
})
.collect::<Result<Vec<_>, _>>()?;
let columns = cast_record_batch_columns(batch, target_schema.fields(), &cast_options)?;
Ok(RecordBatch::try_new(target_schema, columns)?)
}

Expand All @@ -516,7 +541,10 @@ mod tests {
use super::*;
use crate::operations::DeltaOps;
use crate::writer::test_utils::datafusion::get_data;
use crate::writer::test_utils::{get_delta_schema, get_record_batch};
use crate::writer::test_utils::{
get_delta_schema, get_delta_schema_with_nested_struct, get_record_batch,
get_record_batch_with_nested_struct,
};
use arrow::datatypes::Field;
use arrow::datatypes::Schema as ArrowSchema;
use arrow_array::{Int32Array, StringArray, TimestampMicrosecondArray};
Expand Down Expand Up @@ -743,4 +771,35 @@ mod tests {
let table = DeltaOps(table).write(vec![batch.clone()]).await;
assert!(table.is_err())
}

#[tokio::test]
async fn test_nested_struct() {
let table_schema = get_delta_schema_with_nested_struct();
let batch = get_record_batch_with_nested_struct();

let table = DeltaOps::new_in_memory()
.create()
.with_columns(table_schema.get_fields().clone())
.await
.unwrap();
assert_eq!(table.version(), 0);

let table = DeltaOps(table)
.write(vec![batch.clone()])
.with_save_mode(SaveMode::Append)
.await
.unwrap();
assert_eq!(table.version(), 1);

let actual = get_data(&table).await;
let expected = DataType::Struct(Fields::from(vec![Field::new(
"count",
DataType::Int32,
true,
)]));
assert_eq!(
actual[0].column_by_name("nested").unwrap().data_type(),
&expected
);
}
}
122 changes: 120 additions & 2 deletions rust/src/writer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use std::collections::HashMap;
use std::sync::Arc;

use arrow::compute::take;
use arrow_array::{Int32Array, RecordBatch, StringArray, UInt32Array};
use arrow_array::{Int32Array, Int64Array, RecordBatch, StringArray, StructArray, UInt32Array};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};

use crate::delta::DeltaTableMetaData;
use crate::operations::create::CreateBuilder;
use crate::schema::Schema;
use crate::{DeltaTable, DeltaTableBuilder, SchemaDataType, SchemaField};
use crate::{DeltaTable, DeltaTableBuilder, SchemaDataType, SchemaField, SchemaTypeStruct};

pub type TestResult = Result<(), Box<dyn std::error::Error + 'static>>;

Expand Down Expand Up @@ -164,6 +164,124 @@ pub fn get_delta_metadata(partition_cols: &[String]) -> DeltaTableMetaData {
)
}

pub fn get_record_batch_with_nested_struct() -> RecordBatch {
let nested_schema = Arc::new(ArrowSchema::new(vec![Field::new(
"count",
DataType::Int64,
true,
)]));
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Utf8, true),
Field::new("value", DataType::Int32, true),
Field::new("modified", DataType::Utf8, true),
Field::new(
"nested",
DataType::Struct(nested_schema.fields().clone()),
true,
),
]));

let count_array = Int64Array::from(vec![
Some(1),
Some(2),
Some(3),
Some(4),
Some(5),
Some(6),
Some(7),
Some(8),
Some(9),
Some(10),
Some(11),
]);
let id_array = StringArray::from(vec![
Some("A"),
Some("B"),
None,
Some("B"),
Some("A"),
Some("A"),
None,
None,
Some("B"),
Some("A"),
Some("A"),
]);
let value_array = Int32Array::from(vec![
Some(1),
Some(2),
Some(3),
Some(4),
Some(5),
Some(6),
Some(7),
Some(8),
Some(9),
Some(10),
Some(11),
]);
let modified_array = StringArray::from(vec![
Some("2021-02-02"),
Some("2021-02-02"),
Some("2021-02-02"),
Some("2021-02-01"),
Some("2021-02-01"),
Some("2021-02-01"),
Some("2021-02-01"),
Some("2021-02-01"),
Some("2021-02-01"),
Some("2021-02-01"),
Some("2021-02-01"),
]);

RecordBatch::try_new(
schema,
vec![
Arc::new(id_array),
Arc::new(value_array),
Arc::new(modified_array),
Arc::new(StructArray::from(
RecordBatch::try_new(nested_schema, vec![Arc::new(count_array)]).unwrap(),
)),
],
)
.unwrap()
}

pub fn get_delta_schema_with_nested_struct() -> Schema {
Schema::new(vec![
SchemaField::new(
"id".to_string(),
SchemaDataType::primitive("string".to_string()),
true,
HashMap::new(),
),
SchemaField::new(
"value".to_string(),
SchemaDataType::primitive("integer".to_string()),
true,
HashMap::new(),
),
SchemaField::new(
"modified".to_string(),
SchemaDataType::primitive("string".to_string()),
true,
HashMap::new(),
),
SchemaField::new(
String::from("nested"),
SchemaDataType::r#struct(SchemaTypeStruct::new(vec![SchemaField::new(
String::from("count"),
SchemaDataType::primitive(String::from("integer")),
true,
Default::default(),
)])),
true,
Default::default(),
),
])
}

pub fn create_bare_table() -> DeltaTable {
let table_dir = tempfile::tempdir().unwrap();
let table_path = table_dir.path();
Expand Down

0 comments on commit 4a4aaa9

Please sign in to comment.