Skip to content

Commit 001f740

Browse files
committed
arrow-cast: Attempt finding struct field on inconsistent order
1 parent a01a465 commit 001f740

File tree

1 file changed

+64
-10
lines changed

1 file changed

+64
-10
lines changed

arrow-cast/src/cast/mod.rs

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,34 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
221221
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _),
222222
) => true,
223223
(Struct(from_fields), Struct(to_fields)) => {
224-
from_fields.len() == to_fields.len()
225-
&& from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| {
224+
if from_fields.len() != to_fields.len() {
225+
return false;
226+
}
227+
228+
// fast path, all field names are in the same order and same number of fields
229+
if from_fields
230+
.iter()
231+
.zip(to_fields.iter())
232+
.all(|(f1, f2)| f1.name() == f2.name())
233+
{
234+
return from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| {
226235
// Assume that nullability between two structs are compatible, if not,
227236
// cast kernel will return error.
228237
can_cast_types(f1.data_type(), f2.data_type())
229-
})
238+
});
239+
}
240+
241+
// slow path, we match the fields by name
242+
to_fields.iter().all(|to_field| {
243+
from_fields
244+
.iter()
245+
.find(|from_field| from_field.name() == to_field.name())
246+
.is_some_and(|from_field| {
247+
// Assume that nullability between two structs are compatible, if not,
248+
// cast kernel will return error.
249+
can_cast_types(from_field.data_type(), to_field.data_type())
250+
})
251+
})
230252
}
231253
(Struct(_), _) => false,
232254
(_, Struct(_)) => false,
@@ -1169,14 +1191,46 @@ pub fn cast_with_options(
11691191
cast_options,
11701192
)
11711193
}
1172-
(Struct(_), Struct(to_fields)) => {
1194+
(Struct(from_fields), Struct(to_fields)) => {
11731195
let array = array.as_struct();
1174-
let fields = array
1175-
.columns()
1176-
.iter()
1177-
.zip(to_fields.iter())
1178-
.map(|(l, field)| cast_with_options(l, field.data_type(), cast_options))
1179-
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?;
1196+
1197+
// Fast path: if field names are in the same order, we can just zip and cast
1198+
let fields_match_order = from_fields.len() == to_fields.len()
1199+
&& from_fields
1200+
.iter()
1201+
.zip(to_fields.iter())
1202+
.all(|(f1, f2)| f1.name() == f2.name());
1203+
1204+
let fields = if fields_match_order {
1205+
// Fast path: cast columns in order
1206+
array
1207+
.columns()
1208+
.iter()
1209+
.zip(to_fields.iter())
1210+
.map(|(column, field)| {
1211+
cast_with_options(column, field.data_type(), cast_options)
1212+
})
1213+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1214+
} else {
1215+
// Slow path: match fields by name and reorder
1216+
to_fields
1217+
.iter()
1218+
.map(|to_field| {
1219+
let from_field_idx = from_fields
1220+
.iter()
1221+
.position(|from_field| from_field.name() == to_field.name())
1222+
.ok_or_else(|| {
1223+
ArrowError::CastError(format!(
1224+
"Field '{}' not found in source struct",
1225+
to_field.name()
1226+
))
1227+
})?;
1228+
let column = array.column(from_field_idx);
1229+
cast_with_options(column, to_field.data_type(), cast_options)
1230+
})
1231+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1232+
};
1233+
11801234
let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?;
11811235
Ok(Arc::new(array) as ArrayRef)
11821236
}

0 commit comments

Comments
 (0)