Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cast list items to default before write with different item names #1959

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Extend cast condition for lists
  • Loading branch information
Jonas Schmitz committed Dec 12, 2023
commit 0d08134c160c53506446ce8046747e8a6b578b70
85 changes: 84 additions & 1 deletion crates/deltalake-core/src/operations/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn cast_record_batch_columns(
.iter()
.map(|f| {
let col = batch.column_by_name(f.name()).unwrap();

if let (DataType::Struct(_), DataType::Struct(child_fields)) =
(col.data_type(), f.data_type())
{
Expand All @@ -28,7 +29,7 @@ fn cast_record_batch_columns(
child_columns.clone(),
None,
)) as ArrayRef)
} else if !col.data_type().equals_datatype(f.data_type()) {
} else if is_cast_required(col.data_type(), f.data_type()) {
cast_with_options(col, f.data_type(), cast_options)
} else {
Ok(col.clone())
Expand All @@ -37,6 +38,18 @@ fn cast_record_batch_columns(
.collect::<Result<Vec<_>, _>>()
}

fn is_cast_required(a: &DataType, b: &DataType) -> bool {
match (a, b) {
(DataType::List(a_item), DataType::List(b_item)) => {
// If list item name is not the default('item') the list must be casted
!a.equals_datatype(b) || a_item.name() != b_item.name()
}
(_, _) => {
!a.equals_datatype(b)
}
}
}

/// Cast recordbatch to a new target_schema, by casting each column array
pub fn cast_record_batch(
batch: &RecordBatch,
Expand All @@ -51,3 +64,73 @@ pub fn cast_record_batch(
let columns = cast_record_batch_columns(batch, target_schema.fields(), &cast_options)?;
Ok(RecordBatch::try_new(target_schema, columns)?)
}


#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::ArrayData;
use arrow_array::{Array, ArrayRef, ListArray, RecordBatch};
use arrow_buffer::Buffer;
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
use crate::operations::cast::{cast_record_batch, is_cast_required};

#[test]
fn test_cast_record_batch_with_list_non_default_item(){
let array = Arc::new(make_list_array()) as ArrayRef;
let source_schema = Schema::new(vec![
Field::new("list_column", array.data_type().clone(), false)
]);
let record_batch = RecordBatch::try_new(Arc::new(source_schema), vec![array]).unwrap();

let fields = Fields::from(vec![Field::new_list("list_column", Field::new("item", DataType::Int8, false), false)]);
let target_schema = Arc::new(Schema::new(fields)) as SchemaRef;

let result = cast_record_batch(&record_batch, target_schema, false);

let schema = result.unwrap().schema();
let field = schema.column_with_name("list_column").unwrap().1;
if let DataType::List(list_item) = field.data_type() {
assert_eq!(list_item.name(), "item");
} else {
panic!("Not a list");
}
}

fn make_list_array() -> ListArray {
let value_data = ArrayData::builder(DataType::Int32)
.len(8)
.add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
.build()
.unwrap();

let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);

let list_data_type =
DataType::List(Arc::new(Field::new("element", DataType::Int32, true)));
let list_data = ArrayData::builder(list_data_type)
.len(3)
.add_buffer(value_offsets)
.add_child_data(value_data)
.build()
.unwrap();
ListArray::from(list_data)
}

#[test]
fn test_is_cast_required_with_list() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));

assert!(!is_cast_required(&field1, &field2));
}

#[test]
fn test_is_cast_required_with_list_non_default_item() {
let field1 = DataType::List(FieldRef::from(Field::new("element", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));

assert!(is_cast_required(&field1, &field2));
}

}