Skip to content

Commit 7e637a7

Browse files
authored
Struct casting field order (#8871)
# Which issue does this PR close?
Closes #8870.

# What changes are included in this PR?
Check whether the field order of the source and target structs matches when casting; if it does not, attempt to find the fields by name.

# Are these changes tested?
Added unit tests (that previously failed, so I separated them into their own commit).

# Are there any user-facing changes?
No, it's strictly additive functionality.

@alamb @vegarsti
1 parent 6ec5d6d commit 7e637a7

File tree

1 file changed

+229
-16
lines changed

1 file changed

+229
-16
lines changed

arrow-cast/src/cast/mod.rs

Lines changed: 229 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,34 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
221221
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _),
222222
) => true,
223223
(Struct(from_fields), Struct(to_fields)) => {
224-
from_fields.len() == to_fields.len()
225-
&& from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| {
224+
if from_fields.len() != to_fields.len() {
225+
return false;
226+
}
227+
228+
// fast path, all field names are in the same order and same number of fields
229+
if from_fields
230+
.iter()
231+
.zip(to_fields.iter())
232+
.all(|(f1, f2)| f1.name() == f2.name())
233+
{
234+
return from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| {
226235
// Assume that nullability between two structs are compatible, if not,
227236
// cast kernel will return error.
228237
can_cast_types(f1.data_type(), f2.data_type())
229-
})
238+
});
239+
}
240+
241+
// slow path, we match the fields by name
242+
to_fields.iter().all(|to_field| {
243+
from_fields
244+
.iter()
245+
.find(|from_field| from_field.name() == to_field.name())
246+
.is_some_and(|from_field| {
247+
// Assume that nullability between two structs are compatible, if not,
248+
// cast kernel will return error.
249+
can_cast_types(from_field.data_type(), to_field.data_type())
250+
})
251+
})
230252
}
231253
(Struct(_), _) => false,
232254
(_, Struct(_)) => false,
@@ -1169,14 +1191,46 @@ pub fn cast_with_options(
11691191
cast_options,
11701192
)
11711193
}
1172-
(Struct(_), Struct(to_fields)) => {
1194+
(Struct(from_fields), Struct(to_fields)) => {
11731195
let array = array.as_struct();
1174-
let fields = array
1175-
.columns()
1176-
.iter()
1177-
.zip(to_fields.iter())
1178-
.map(|(l, field)| cast_with_options(l, field.data_type(), cast_options))
1179-
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?;
1196+
1197+
// Fast path: if field names are in the same order, we can just zip and cast
1198+
let fields_match_order = from_fields.len() == to_fields.len()
1199+
&& from_fields
1200+
.iter()
1201+
.zip(to_fields.iter())
1202+
.all(|(f1, f2)| f1.name() == f2.name());
1203+
1204+
let fields = if fields_match_order {
1205+
// Fast path: cast columns in order
1206+
array
1207+
.columns()
1208+
.iter()
1209+
.zip(to_fields.iter())
1210+
.map(|(column, field)| {
1211+
cast_with_options(column, field.data_type(), cast_options)
1212+
})
1213+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1214+
} else {
1215+
// Slow path: match fields by name and reorder
1216+
to_fields
1217+
.iter()
1218+
.map(|to_field| {
1219+
let from_field_idx = from_fields
1220+
.iter()
1221+
.position(|from_field| from_field.name() == to_field.name())
1222+
.ok_or_else(|| {
1223+
ArrowError::CastError(format!(
1224+
"Field '{}' not found in source struct",
1225+
to_field.name()
1226+
))
1227+
})?;
1228+
let column = array.column(from_field_idx);
1229+
cast_with_options(column, to_field.data_type(), cast_options)
1230+
})
1231+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1232+
};
1233+
11801234
let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?;
11811235
Ok(Arc::new(array) as ArrayRef)
11821236
}
@@ -10836,11 +10890,11 @@ mod tests {
1083610890
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1083710891
let struct_array = StructArray::from(vec![
1083810892
(
10839-
Arc::new(Field::new("b", DataType::Boolean, false)),
10893+
Arc::new(Field::new("a", DataType::Boolean, false)),
1084010894
boolean.clone() as ArrayRef,
1084110895
),
1084210896
(
10843-
Arc::new(Field::new("c", DataType::Int32, false)),
10897+
Arc::new(Field::new("b", DataType::Int32, false)),
1084410898
int.clone() as ArrayRef,
1084510899
),
1084610900
]);
@@ -10884,11 +10938,11 @@ mod tests {
1088410938
let int = Arc::new(Int32Array::from(vec![Some(42), None, Some(19), None]));
1088510939
let struct_array = StructArray::from(vec![
1088610940
(
10887-
Arc::new(Field::new("b", DataType::Boolean, false)),
10941+
Arc::new(Field::new("a", DataType::Boolean, false)),
1088810942
boolean.clone() as ArrayRef,
1088910943
),
1089010944
(
10891-
Arc::new(Field::new("c", DataType::Int32, true)),
10945+
Arc::new(Field::new("b", DataType::Int32, true)),
1089210946
int.clone() as ArrayRef,
1089310947
),
1089410948
]);
@@ -10918,11 +10972,11 @@ mod tests {
1091810972
let int = Arc::new(Int32Array::from(vec![i32::MAX, 25, 1, 100]));
1091910973
let struct_array = StructArray::from(vec![
1092010974
(
10921-
Arc::new(Field::new("b", DataType::Boolean, false)),
10975+
Arc::new(Field::new("a", DataType::Boolean, false)),
1092210976
boolean.clone() as ArrayRef,
1092310977
),
1092410978
(
10925-
Arc::new(Field::new("c", DataType::Int32, false)),
10979+
Arc::new(Field::new("b", DataType::Int32, false)),
1092610980
int.clone() as ArrayRef,
1092710981
),
1092810982
]);
@@ -10977,6 +11031,165 @@ mod tests {
1097711031
);
1097811032
}
1097911033

11034+
#[test]
11035+
fn test_cast_struct_with_different_field_order() {
11036+
// Test slow path: fields are in different order
11037+
let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
11038+
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
11039+
let string = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"]));
11040+
11041+
let struct_array = StructArray::from(vec![
11042+
(
11043+
Arc::new(Field::new("a", DataType::Boolean, false)),
11044+
boolean.clone() as ArrayRef,
11045+
),
11046+
(
11047+
Arc::new(Field::new("b", DataType::Int32, false)),
11048+
int.clone() as ArrayRef,
11049+
),
11050+
(
11051+
Arc::new(Field::new("c", DataType::Utf8, false)),
11052+
string.clone() as ArrayRef,
11053+
),
11054+
]);
11055+
11056+
// Target has fields in different order: c, a, b instead of a, b, c
11057+
let to_type = DataType::Struct(
11058+
vec![
11059+
Field::new("c", DataType::Utf8, false),
11060+
Field::new("a", DataType::Utf8, false), // Boolean to Utf8
11061+
Field::new("b", DataType::Utf8, false), // Int32 to Utf8
11062+
]
11063+
.into(),
11064+
);
11065+
11066+
let result = cast(&struct_array, &to_type).unwrap();
11067+
let result_struct = result.as_struct();
11068+
11069+
assert_eq!(result_struct.data_type(), &to_type);
11070+
assert_eq!(result_struct.num_columns(), 3);
11071+
11072+
// Verify field "c" (originally position 2, now position 0) remains Utf8
11073+
let c_column = result_struct.column(0).as_string::<i32>();
11074+
assert_eq!(
11075+
c_column.into_iter().flatten().collect::<Vec<_>>(),
11076+
vec!["foo", "bar", "baz", "qux"]
11077+
);
11078+
11079+
// Verify field "a" (originally position 0, now position 1) was cast from Boolean to Utf8
11080+
let a_column = result_struct.column(1).as_string::<i32>();
11081+
assert_eq!(
11082+
a_column.into_iter().flatten().collect::<Vec<_>>(),
11083+
vec!["false", "false", "true", "true"]
11084+
);
11085+
11086+
// Verify field "b" (originally position 1, now position 2) was cast from Int32 to Utf8
11087+
let b_column = result_struct.column(2).as_string::<i32>();
11088+
assert_eq!(
11089+
b_column.into_iter().flatten().collect::<Vec<_>>(),
11090+
vec!["42", "28", "19", "31"]
11091+
);
11092+
}
11093+
11094+
#[test]
11095+
fn test_cast_struct_with_missing_field() {
11096+
// Test that casting fails when target has a field not present in source
11097+
let boolean = Arc::new(BooleanArray::from(vec![false, true]));
11098+
let struct_array = StructArray::from(vec![(
11099+
Arc::new(Field::new("a", DataType::Boolean, false)),
11100+
boolean.clone() as ArrayRef,
11101+
)]);
11102+
11103+
let to_type = DataType::Struct(
11104+
vec![
11105+
Field::new("a", DataType::Utf8, false),
11106+
Field::new("b", DataType::Int32, false), // Field "b" doesn't exist in source
11107+
]
11108+
.into(),
11109+
);
11110+
11111+
let result = cast(&struct_array, &to_type);
11112+
assert!(result.is_err());
11113+
assert_eq!(
11114+
result.unwrap_err().to_string(),
11115+
"Cast error: Field 'b' not found in source struct"
11116+
);
11117+
}
11118+
11119+
#[test]
11120+
fn test_cast_struct_with_subset_of_fields() {
11121+
// Test casting to a struct with fewer fields (selecting a subset)
11122+
let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
11123+
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
11124+
let string = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"]));
11125+
11126+
let struct_array = StructArray::from(vec![
11127+
(
11128+
Arc::new(Field::new("a", DataType::Boolean, false)),
11129+
boolean.clone() as ArrayRef,
11130+
),
11131+
(
11132+
Arc::new(Field::new("b", DataType::Int32, false)),
11133+
int.clone() as ArrayRef,
11134+
),
11135+
(
11136+
Arc::new(Field::new("c", DataType::Utf8, false)),
11137+
string.clone() as ArrayRef,
11138+
),
11139+
]);
11140+
11141+
// Target has only fields "c" and "a", omitting "b"
11142+
let to_type = DataType::Struct(
11143+
vec![
11144+
Field::new("c", DataType::Utf8, false),
11145+
Field::new("a", DataType::Utf8, false),
11146+
]
11147+
.into(),
11148+
);
11149+
11150+
let result = cast(&struct_array, &to_type).unwrap();
11151+
let result_struct = result.as_struct();
11152+
11153+
assert_eq!(result_struct.data_type(), &to_type);
11154+
assert_eq!(result_struct.num_columns(), 2);
11155+
11156+
// Verify field "c" remains Utf8
11157+
let c_column = result_struct.column(0).as_string::<i32>();
11158+
assert_eq!(
11159+
c_column.into_iter().flatten().collect::<Vec<_>>(),
11160+
vec!["foo", "bar", "baz", "qux"]
11161+
);
11162+
11163+
// Verify field "a" was cast from Boolean to Utf8
11164+
let a_column = result_struct.column(1).as_string::<i32>();
11165+
assert_eq!(
11166+
a_column.into_iter().flatten().collect::<Vec<_>>(),
11167+
vec!["false", "false", "true", "true"]
11168+
);
11169+
}
11170+
11171+
#[test]
11172+
fn test_can_cast_struct_with_missing_field() {
11173+
// Test that can_cast_types returns false when target has a field not in source
11174+
let from_type = DataType::Struct(
11175+
vec![
11176+
Field::new("a", DataType::Int32, false),
11177+
Field::new("b", DataType::Utf8, false),
11178+
]
11179+
.into(),
11180+
);
11181+
11182+
let to_type = DataType::Struct(
11183+
vec![
11184+
Field::new("a", DataType::Int64, false),
11185+
Field::new("c", DataType::Boolean, false), // Field "c" not in source
11186+
]
11187+
.into(),
11188+
);
11189+
11190+
assert!(!can_cast_types(&from_type, &to_type));
11191+
}
11192+
1098011193
fn run_decimal_cast_test_case_between_multiple_types(t: DecimalCastTestConfig) {
1098111194
run_decimal_cast_test_case::<Decimal128Type, Decimal128Type>(t.clone());
1098211195
run_decimal_cast_test_case::<Decimal128Type, Decimal256Type>(t.clone());

0 commit comments

Comments
 (0)