@@ -221,12 +221,34 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
221221 Decimal32 ( _, _) | Decimal64 ( _, _) | Decimal128 ( _, _) | Decimal256 ( _, _) ,
222222 ) => true ,
223223 ( Struct ( from_fields) , Struct ( to_fields) ) => {
224- from_fields. len ( ) == to_fields. len ( )
225- && from_fields. iter ( ) . zip ( to_fields. iter ( ) ) . all ( |( f1, f2) | {
224+ if from_fields. len ( ) != to_fields. len ( ) {
225+ return false ;
226+ }
227+
228+ // fast path, all field names are in the same order and same number of fields
229+ if from_fields
230+ . iter ( )
231+ . zip ( to_fields. iter ( ) )
232+ . all ( |( f1, f2) | f1. name ( ) == f2. name ( ) )
233+ {
234+ return from_fields. iter ( ) . zip ( to_fields. iter ( ) ) . all ( |( f1, f2) | {
226235 // Assume that nullability between two structs are compatible, if not,
227236 // cast kernel will return error.
228237 can_cast_types ( f1. data_type ( ) , f2. data_type ( ) )
229- } )
238+ } ) ;
239+ }
240+
241+ // slow path, we match the fields by name
242+ to_fields. iter ( ) . all ( |to_field| {
243+ from_fields
244+ . iter ( )
245+ . find ( |from_field| from_field. name ( ) == to_field. name ( ) )
246+ . is_some_and ( |from_field| {
247+ // Assume that nullability between two structs are compatible, if not,
248+ // cast kernel will return error.
249+ can_cast_types ( from_field. data_type ( ) , to_field. data_type ( ) )
250+ } )
251+ } )
230252 }
231253 ( Struct ( _) , _) => false ,
232254 ( _, Struct ( _) ) => false ,
@@ -1169,14 +1191,46 @@ pub fn cast_with_options(
11691191 cast_options,
11701192 )
11711193 }
1172- ( Struct ( _ ) , Struct ( to_fields) ) => {
1194+ ( Struct ( from_fields ) , Struct ( to_fields) ) => {
11731195 let array = array. as_struct ( ) ;
1174- let fields = array
1175- . columns ( )
1176- . iter ( )
1177- . zip ( to_fields. iter ( ) )
1178- . map ( |( l, field) | cast_with_options ( l, field. data_type ( ) , cast_options) )
1179- . collect :: < Result < Vec < ArrayRef > , ArrowError > > ( ) ?;
1196+
1197+ // Fast path: if field names are in the same order, we can just zip and cast
1198+ let fields_match_order = from_fields. len ( ) == to_fields. len ( )
1199+ && from_fields
1200+ . iter ( )
1201+ . zip ( to_fields. iter ( ) )
1202+ . all ( |( f1, f2) | f1. name ( ) == f2. name ( ) ) ;
1203+
1204+ let fields = if fields_match_order {
1205+ // Fast path: cast columns in order
1206+ array
1207+ . columns ( )
1208+ . iter ( )
1209+ . zip ( to_fields. iter ( ) )
1210+ . map ( |( column, field) | {
1211+ cast_with_options ( column, field. data_type ( ) , cast_options)
1212+ } )
1213+ . collect :: < Result < Vec < ArrayRef > , ArrowError > > ( ) ?
1214+ } else {
1215+ // Slow path: match fields by name and reorder
1216+ to_fields
1217+ . iter ( )
1218+ . map ( |to_field| {
1219+ let from_field_idx = from_fields
1220+ . iter ( )
1221+ . position ( |from_field| from_field. name ( ) == to_field. name ( ) )
1222+ . ok_or_else ( || {
1223+ ArrowError :: CastError ( format ! (
1224+ "Field '{}' not found in source struct" ,
1225+ to_field. name( )
1226+ ) )
1227+ } ) ?;
1228+ let column = array. column ( from_field_idx) ;
1229+ cast_with_options ( column, to_field. data_type ( ) , cast_options)
1230+ } )
1231+ . collect :: < Result < Vec < ArrayRef > , ArrowError > > ( ) ?
1232+ } ;
1233+
11801234 let array = StructArray :: try_new ( to_fields. clone ( ) , fields, array. nulls ( ) . cloned ( ) ) ?;
11811235 Ok ( Arc :: new ( array) as ArrayRef )
11821236 }
@@ -10836,11 +10890,11 @@ mod tests {
1083610890 let int = Arc :: new ( Int32Array :: from ( vec ! [ 42 , 28 , 19 , 31 ] ) ) ;
1083710891 let struct_array = StructArray :: from ( vec ! [
1083810892 (
10839- Arc :: new( Field :: new( "b " , DataType :: Boolean , false ) ) ,
10893+ Arc :: new( Field :: new( "a " , DataType :: Boolean , false ) ) ,
1084010894 boolean. clone( ) as ArrayRef ,
1084110895 ) ,
1084210896 (
10843- Arc :: new( Field :: new( "c " , DataType :: Int32 , false ) ) ,
10897+ Arc :: new( Field :: new( "b " , DataType :: Int32 , false ) ) ,
1084410898 int. clone( ) as ArrayRef ,
1084510899 ) ,
1084610900 ] ) ;
@@ -10884,11 +10938,11 @@ mod tests {
1088410938 let int = Arc :: new ( Int32Array :: from ( vec ! [ Some ( 42 ) , None , Some ( 19 ) , None ] ) ) ;
1088510939 let struct_array = StructArray :: from ( vec ! [
1088610940 (
10887- Arc :: new( Field :: new( "b " , DataType :: Boolean , false ) ) ,
10941+ Arc :: new( Field :: new( "a " , DataType :: Boolean , false ) ) ,
1088810942 boolean. clone( ) as ArrayRef ,
1088910943 ) ,
1089010944 (
10891- Arc :: new( Field :: new( "c " , DataType :: Int32 , true ) ) ,
10945+ Arc :: new( Field :: new( "b " , DataType :: Int32 , true ) ) ,
1089210946 int. clone( ) as ArrayRef ,
1089310947 ) ,
1089410948 ] ) ;
@@ -10918,11 +10972,11 @@ mod tests {
1091810972 let int = Arc :: new ( Int32Array :: from ( vec ! [ i32 :: MAX , 25 , 1 , 100 ] ) ) ;
1091910973 let struct_array = StructArray :: from ( vec ! [
1092010974 (
10921- Arc :: new( Field :: new( "b " , DataType :: Boolean , false ) ) ,
10975+ Arc :: new( Field :: new( "a " , DataType :: Boolean , false ) ) ,
1092210976 boolean. clone( ) as ArrayRef ,
1092310977 ) ,
1092410978 (
10925- Arc :: new( Field :: new( "c " , DataType :: Int32 , false ) ) ,
10979+ Arc :: new( Field :: new( "b " , DataType :: Int32 , false ) ) ,
1092610980 int. clone( ) as ArrayRef ,
1092710981 ) ,
1092810982 ] ) ;
@@ -10977,6 +11031,165 @@ mod tests {
1097711031 ) ;
1097811032 }
1097911033
#[test]
fn test_cast_struct_with_different_field_order() {
    // Exercises the by-name (slow) path: the target struct lists the same
    // field names as the source, but in a different order.
    let bool_values: ArrayRef = Arc::new(BooleanArray::from(vec![false, false, true, true]));
    let int_values: ArrayRef = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
    let str_values: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"]));

    // Source layout: a (Boolean), b (Int32), c (Utf8).
    let source = StructArray::from(vec![
        (
            Arc::new(Field::new("a", DataType::Boolean, false)),
            bool_values,
        ),
        (
            Arc::new(Field::new("b", DataType::Int32, false)),
            int_values,
        ),
        (
            Arc::new(Field::new("c", DataType::Utf8, false)),
            str_values,
        ),
    ]);

    // Target layout: c, a, b -- every field is Utf8, so "a" and "b" must
    // also be converted while being moved.
    let target_fields = vec![
        Field::new("c", DataType::Utf8, false),
        Field::new("a", DataType::Utf8, false), // Boolean to Utf8
        Field::new("b", DataType::Utf8, false), // Int32 to Utf8
    ];
    let target_type = DataType::Struct(target_fields.into());

    let casted = cast(&source, &target_type).unwrap();
    let casted_struct = casted.as_struct();

    assert_eq!(casted_struct.data_type(), &target_type);
    assert_eq!(casted_struct.num_columns(), 3);

    // Expected string contents of each output column, indexed by the
    // column's position in the target struct.
    let expected_columns: [Vec<&str>; 3] = [
        // "c": source position 2 -> target position 0, already Utf8
        vec!["foo", "bar", "baz", "qux"],
        // "a": source position 0 -> target position 1, Boolean rendered as text
        vec!["false", "false", "true", "true"],
        // "b": source position 1 -> target position 2, Int32 rendered as text
        vec!["42", "28", "19", "31"],
    ];

    for (position, want) in expected_columns.iter().enumerate() {
        let got: Vec<&str> = casted_struct
            .column(position)
            .as_string::<i32>()
            .iter()
            .flatten()
            .collect();
        assert_eq!(&got, want);
    }
}
11093+
#[test]
fn test_cast_struct_with_missing_field() {
    // The target struct asks for a field ("b") that the source struct does
    // not have, so the cast is expected to report an error.
    let source = StructArray::from(vec![(
        Arc::new(Field::new("a", DataType::Boolean, false)),
        Arc::new(BooleanArray::from(vec![false, true])) as ArrayRef,
    )]);

    let target_fields = vec![
        Field::new("a", DataType::Utf8, false),
        Field::new("b", DataType::Int32, false), // not present in the source
    ];
    let target_type = DataType::Struct(target_fields.into());

    let outcome = cast(&source, &target_type);
    assert!(outcome.is_err());

    // The error text must name the field that could not be located.
    let message = outcome.unwrap_err().to_string();
    assert_eq!(message, "Cast error: Field 'b' not found in source struct");
}
11118+
#[test]
fn test_cast_struct_with_subset_of_fields() {
    // Casting to a struct that keeps only some of the source fields: the
    // target names "c" and "a" and leaves "b" out.
    let bool_values: ArrayRef = Arc::new(BooleanArray::from(vec![false, false, true, true]));
    let int_values: ArrayRef = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
    let str_values: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"]));

    // Source layout: a (Boolean), b (Int32), c (Utf8).
    let source = StructArray::from(vec![
        (
            Arc::new(Field::new("a", DataType::Boolean, false)),
            bool_values,
        ),
        (
            Arc::new(Field::new("b", DataType::Int32, false)),
            int_values,
        ),
        (
            Arc::new(Field::new("c", DataType::Utf8, false)),
            str_values,
        ),
    ]);

    // Target layout: c (Utf8), a (Utf8).
    let target_fields = vec![
        Field::new("c", DataType::Utf8, false),
        Field::new("a", DataType::Utf8, false),
    ];
    let target_type = DataType::Struct(target_fields.into());

    let casted = cast(&source, &target_type).unwrap();
    let casted_struct = casted.as_struct();

    assert_eq!(casted_struct.data_type(), &target_type);
    assert_eq!(casted_struct.num_columns(), 2);

    // Expected string contents per output column, in target order.
    let expected_columns: [Vec<&str>; 2] = [
        // "c" stays Utf8
        vec!["foo", "bar", "baz", "qux"],
        // "a" is converted from Boolean to Utf8
        vec!["false", "false", "true", "true"],
    ];

    for (position, want) in expected_columns.iter().enumerate() {
        let got: Vec<&str> = casted_struct
            .column(position)
            .as_string::<i32>()
            .iter()
            .flatten()
            .collect();
        assert_eq!(&got, want);
    }
}
11170+
#[test]
fn test_can_cast_struct_with_missing_field() {
    // `can_cast_types` must answer `false` when the target struct names a
    // field ("c") that does not appear in the source struct.
    let struct_of = |fields: Vec<Field>| DataType::Struct(fields.into());

    let from_type = struct_of(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]);

    let to_type = struct_of(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("c", DataType::Boolean, false), // not present in the source
    ]);

    let castable = can_cast_types(&from_type, &to_type);
    assert!(!castable);
}
11192+
1098011193 fn run_decimal_cast_test_case_between_multiple_types ( t : DecimalCastTestConfig ) {
1098111194 run_decimal_cast_test_case :: < Decimal128Type , Decimal128Type > ( t. clone ( ) ) ;
1098211195 run_decimal_cast_test_case :: < Decimal128Type , Decimal256Type > ( t. clone ( ) ) ;
0 commit comments