Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 229 additions & 16 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,34 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _),
) => true,
(Struct(from_fields), Struct(to_fields)) => {
from_fields.len() == to_fields.len()
&& from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| {
if from_fields.len() != to_fields.len() {
return false;
}

// fast path, all field names are in the same order and same number of fields
if from_fields
.iter()
.zip(to_fields.iter())
.all(|(f1, f2)| f1.name() == f2.name())
{
return from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| {
// Assume that nullability between two structs are compatible, if not,
// cast kernel will return error.
can_cast_types(f1.data_type(), f2.data_type())
})
});
}

// slow path, we match the fields by name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one idea that has come up in the past is to do this mapping calculation once and then use it for both can_cast_types and cast

However, this seems to be strictly better than current main (doesn't slow down existing code and allows more uses, so 👍 to me)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also going back and forth, but I decided that an additional allocation in the average case would be far worse in perf cost than comparisons. We're heavily profiling all these code paths, if it's significant, we will come back and improve the perf! 😄

to_fields.iter().all(|to_field| {
from_fields
.iter()
.find(|from_field| from_field.name() == to_field.name())
.is_some_and(|from_field| {
// Assume that nullability between two structs are compatible, if not,
// cast kernel will return error.
can_cast_types(from_field.data_type(), to_field.data_type())
})
})
}
(Struct(_), _) => false,
(_, Struct(_)) => false,
Expand Down Expand Up @@ -1169,14 +1191,46 @@ pub fn cast_with_options(
cast_options,
)
}
(Struct(_), Struct(to_fields)) => {
(Struct(from_fields), Struct(to_fields)) => {
let array = array.as_struct();
let fields = array
.columns()
.iter()
.zip(to_fields.iter())
.map(|(l, field)| cast_with_options(l, field.data_type(), cast_options))
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?;

// Fast path: if field names are in the same order, we can just zip and cast
let fields_match_order = from_fields.len() == to_fields.len()
&& from_fields
.iter()
.zip(to_fields.iter())
.all(|(f1, f2)| f1.name() == f2.name());

let fields = if fields_match_order {
// Fast path: cast columns in order
array
.columns()
.iter()
.zip(to_fields.iter())
.map(|(column, field)| {
cast_with_options(column, field.data_type(), cast_options)
})
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
} else {
// Slow path: match fields by name and reorder
to_fields
.iter()
.map(|to_field| {
let from_field_idx = from_fields
.iter()
.position(|from_field| from_field.name() == to_field.name())
.ok_or_else(|| {
ArrowError::CastError(format!(
"Field '{}' not found in source struct",
to_field.name()
))
})?;
let column = array.column(from_field_idx);
cast_with_options(column, to_field.data_type(), cast_options)
})
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
};

let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?;
Ok(Arc::new(array) as ArrayRef)
}
Expand Down Expand Up @@ -10836,11 +10890,11 @@ mod tests {
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("b", DataType::Boolean, false)),
Arc::new(Field::new("a", DataType::Boolean, false)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out these tests were actually wrong to begin with, have a look at the names of the columns, how can a/b be cast to b/c? They only ever worked by accident, and now that we test whether they match, they needed to be fixed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would other people find this a regression (aka that they expect struct fields to be treated in order, rather than by name) 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s possible someone had it incorrect with it accidentally working but I don’t think that should stop fixing what’s clearly a bug. The previous behavior can actually cause very hidden unexpected behavior when it does accidentally work but field names mismatch or have a different order. I can’t see a valid use case for the incorrect previous behavior and all valid behaviors can be represented still and on total mismatches a user will even get a proper error now instead of potentially silently continuing.

boolean.clone() as ArrayRef,
),
(
Arc::new(Field::new("c", DataType::Int32, false)),
Arc::new(Field::new("b", DataType::Int32, false)),
int.clone() as ArrayRef,
),
]);
Expand Down Expand Up @@ -10884,11 +10938,11 @@ mod tests {
let int = Arc::new(Int32Array::from(vec![Some(42), None, Some(19), None]));
let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("b", DataType::Boolean, false)),
Arc::new(Field::new("a", DataType::Boolean, false)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise why was this test changed?

boolean.clone() as ArrayRef,
),
(
Arc::new(Field::new("c", DataType::Int32, true)),
Arc::new(Field::new("b", DataType::Int32, true)),
int.clone() as ArrayRef,
),
]);
Expand Down Expand Up @@ -10918,11 +10972,11 @@ mod tests {
let int = Arc::new(Int32Array::from(vec![i32::MAX, 25, 1, 100]));
let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("b", DataType::Boolean, false)),
Arc::new(Field::new("a", DataType::Boolean, false)),
boolean.clone() as ArrayRef,
),
(
Arc::new(Field::new("c", DataType::Int32, false)),
Arc::new(Field::new("b", DataType::Int32, false)),
int.clone() as ArrayRef,
),
]);
Expand Down Expand Up @@ -10977,6 +11031,165 @@ mod tests {
);
}

#[test]
fn test_cast_struct_with_different_field_order() {
// Test slow path: fields are in different order
let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
let string = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"]));

let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("a", DataType::Boolean, false)),
boolean.clone() as ArrayRef,
),
(
Arc::new(Field::new("b", DataType::Int32, false)),
int.clone() as ArrayRef,
),
(
Arc::new(Field::new("c", DataType::Utf8, false)),
string.clone() as ArrayRef,
),
]);

// Target has fields in different order: c, a, b instead of a, b, c
let to_type = DataType::Struct(
vec![
Field::new("c", DataType::Utf8, false),
Field::new("a", DataType::Utf8, false), // Boolean to Utf8
Field::new("b", DataType::Utf8, false), // Int32 to Utf8
]
.into(),
);

let result = cast(&struct_array, &to_type).unwrap();
let result_struct = result.as_struct();

assert_eq!(result_struct.data_type(), &to_type);
assert_eq!(result_struct.num_columns(), 3);

// Verify field "c" (originally position 2, now position 0) remains Utf8
let c_column = result_struct.column(0).as_string::<i32>();
assert_eq!(
c_column.into_iter().flatten().collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "qux"]
);

// Verify field "a" (originally position 0, now position 1) was cast from Boolean to Utf8
let a_column = result_struct.column(1).as_string::<i32>();
assert_eq!(
a_column.into_iter().flatten().collect::<Vec<_>>(),
vec!["false", "false", "true", "true"]
);

// Verify field "b" (originally position 1, now position 2) was cast from Int32 to Utf8
let b_column = result_struct.column(2).as_string::<i32>();
assert_eq!(
b_column.into_iter().flatten().collect::<Vec<_>>(),
vec!["42", "28", "19", "31"]
);
}

#[test]
fn test_cast_struct_with_missing_field() {
// Test that casting fails when target has a field not present in source
let boolean = Arc::new(BooleanArray::from(vec![false, true]));
let struct_array = StructArray::from(vec![(
Arc::new(Field::new("a", DataType::Boolean, false)),
boolean.clone() as ArrayRef,
)]);

let to_type = DataType::Struct(
vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Int32, false), // Field "b" doesn't exist in source
]
.into(),
);

let result = cast(&struct_array, &to_type);
assert!(result.is_err());
assert_eq!(
result.unwrap_err().to_string(),
"Cast error: Field 'b' not found in source struct"
);
}

#[test]
fn test_cast_struct_with_subset_of_fields() {
// Test casting to a struct with fewer fields (selecting a subset)
let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
let string = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"]));

let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("a", DataType::Boolean, false)),
boolean.clone() as ArrayRef,
),
(
Arc::new(Field::new("b", DataType::Int32, false)),
int.clone() as ArrayRef,
),
(
Arc::new(Field::new("c", DataType::Utf8, false)),
string.clone() as ArrayRef,
),
]);

// Target has only fields "c" and "a", omitting "b"
let to_type = DataType::Struct(
vec![
Field::new("c", DataType::Utf8, false),
Field::new("a", DataType::Utf8, false),
]
.into(),
);

let result = cast(&struct_array, &to_type).unwrap();
let result_struct = result.as_struct();

assert_eq!(result_struct.data_type(), &to_type);
assert_eq!(result_struct.num_columns(), 2);

// Verify field "c" remains Utf8
let c_column = result_struct.column(0).as_string::<i32>();
assert_eq!(
c_column.into_iter().flatten().collect::<Vec<_>>(),
vec!["foo", "bar", "baz", "qux"]
);

// Verify field "a" was cast from Boolean to Utf8
let a_column = result_struct.column(1).as_string::<i32>();
assert_eq!(
a_column.into_iter().flatten().collect::<Vec<_>>(),
vec!["false", "false", "true", "true"]
);
}

#[test]
fn test_can_cast_struct_with_missing_field() {
// Test that can_cast_types returns false when target has a field not in source
let from_type = DataType::Struct(
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, false),
]
.into(),
);

let to_type = DataType::Struct(
vec![
Field::new("a", DataType::Int64, false),
Field::new("c", DataType::Boolean, false), // Field "c" not in source
]
.into(),
);

assert!(!can_cast_types(&from_type, &to_type));
}

fn run_decimal_cast_test_case_between_multiple_types(t: DecimalCastTestConfig) {
run_decimal_cast_test_case::<Decimal128Type, Decimal128Type>(t.clone());
run_decimal_cast_test_case::<Decimal128Type, Decimal256Type>(t.clone());
Expand Down
Loading