@@ -34,7 +34,7 @@ use arrow::{
3434} ;
3535use arrow_array:: builder:: StringBuilder ;
3636use arrow_array:: { DictionaryArray , StringArray , StructArray } ;
37- use arrow_schema:: { DataType , Schema } ;
37+ use arrow_schema:: { DataType , Field , Schema } ;
3838use datafusion_common:: {
3939 cast:: as_generic_string_array, internal_err, Result as DataFusionResult , ScalarValue ,
4040} ;
@@ -714,6 +714,14 @@ fn cast_array(
714714 ( DataType :: Struct ( _) , DataType :: Utf8 ) => {
715715 Ok ( casts_struct_to_string ( array. as_struct ( ) , & timezone) ?)
716716 }
717+ ( DataType :: Struct ( _) , DataType :: Struct ( _) ) => Ok ( cast_struct_to_struct (
718+ array. as_struct ( ) ,
719+ from_type,
720+ to_type,
721+ eval_mode,
722+ timezone,
723+ allow_incompat,
724+ ) ?) ,
717725 _ if is_datafusion_spark_compatible ( from_type, to_type, allow_incompat) => {
718726 // use DataFusion cast only when we know that it is compatible with Spark
719727 Ok ( cast_with_options ( & array, to_type, & CAST_OPTIONS ) ?)
@@ -811,6 +819,35 @@ fn is_datafusion_spark_compatible(
811819 }
812820}
813821
822+ /// Cast between struct types based on logic in
823+ /// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`.
824+ fn cast_struct_to_struct (
825+ array : & StructArray ,
826+ from_type : & DataType ,
827+ to_type : & DataType ,
828+ eval_mode : EvalMode ,
829+ timezone : String ,
830+ allow_incompat : bool ,
831+ ) -> DataFusionResult < ArrayRef > {
832+ match ( from_type, to_type) {
833+ ( DataType :: Struct ( _) , DataType :: Struct ( to_fields) ) => {
834+ let mut cast_fields: Vec < ( Arc < Field > , ArrayRef ) > = Vec :: with_capacity ( to_fields. len ( ) ) ;
835+ for i in 0 ..to_fields. len ( ) {
836+ let cast_field = cast_array (
837+ Arc :: clone ( array. column ( i) ) ,
838+ to_fields[ i] . data_type ( ) ,
839+ eval_mode,
840+ timezone. clone ( ) ,
841+ allow_incompat,
842+ ) ?;
843+ cast_fields. push ( ( Arc :: clone ( & to_fields[ i] ) , cast_field) ) ;
844+ }
845+ Ok ( Arc :: new ( StructArray :: from ( cast_fields) ) )
846+ }
847+ _ => unreachable ! ( ) ,
848+ }
849+ }
850+
814851fn casts_struct_to_string ( array : & StructArray , timezone : & str ) -> DataFusionResult < ArrayRef > {
815852 // cast each field to a string
816853 let string_arrays: Vec < ArrayRef > = array
@@ -1929,7 +1966,7 @@ fn trim_end(s: &str) -> &str {
19291966mod tests {
19301967 use arrow:: datatypes:: TimestampMicrosecondType ;
19311968 use arrow_array:: StringArray ;
1932- use arrow_schema:: { Field , TimeUnit } ;
1969+ use arrow_schema:: { Field , Fields , TimeUnit } ;
19331970 use std:: str:: FromStr ;
19341971
19351972 use super :: * ;
@@ -2336,4 +2373,75 @@ mod tests {
23362373 assert_eq ! ( r#"{4, d}"# , string_array. value( 3 ) ) ;
23372374 assert_eq ! ( r#"{5, e}"# , string_array. value( 4 ) ) ;
23382375 }
2376+
/// Casts a two-field struct to a struct with the same field names where the
/// type of "a" changes from Int32 to Utf8, then checks the row count and the
/// first value of the cast "a" column.
#[test]
fn test_cast_struct_to_struct() {
    // Source struct: { a: Int32 (one null entry), b: Utf8 }
    let int_values = vec![Some(1), Some(2), None, Some(4), Some(5)];
    let int_column: ArrayRef = Arc::new(Int32Array::from(int_values));
    let str_column: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
    let input: ArrayRef = Arc::new(StructArray::from(vec![
        (Arc::new(Field::new("a", DataType::Int32, true)), int_column),
        (Arc::new(Field::new("b", DataType::Utf8, true)), str_column),
    ]));

    // Target struct: "a" becomes Utf8, "b" is unchanged.
    let target_type = DataType::Struct(Fields::from(vec![
        Field::new("a", DataType::Utf8, true),
        Field::new("b", DataType::Utf8, true),
    ]));

    let result = spark_cast(
        ColumnarValue::Array(input),
        &target_type,
        EvalMode::Legacy,
        "UTC",
        false,
    )
    .unwrap();

    match result {
        ColumnarValue::Array(output) => {
            assert_eq!(5, output.len());
            let first_column = output.as_struct().column(0).as_string::<i32>();
            assert_eq!("1", first_column.value(0));
        }
        _ => unreachable!(),
    }
}
2412+
/// Casts a two-field struct to a single-field struct: "a" changes from Int32
/// to Utf8 and "b" is absent from the target type. Checks the row count, that
/// only one column remains, and the first value of the cast "a" column.
#[test]
fn test_cast_struct_to_struct_drop_column() {
    // Source struct: { a: Int32 (one null entry), b: Utf8 }
    let int_values = vec![Some(1), Some(2), None, Some(4), Some(5)];
    let int_column: ArrayRef = Arc::new(Int32Array::from(int_values));
    let str_column: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
    let input: ArrayRef = Arc::new(StructArray::from(vec![
        (Arc::new(Field::new("a", DataType::Int32, true)), int_column),
        (Arc::new(Field::new("b", DataType::Utf8, true)), str_column),
    ]));

    // Target struct keeps only "a", now typed as Utf8.
    let target_type =
        DataType::Struct(Fields::from(vec![Field::new("a", DataType::Utf8, true)]));

    let result = spark_cast(
        ColumnarValue::Array(input),
        &target_type,
        EvalMode::Legacy,
        "UTC",
        false,
    )
    .unwrap();

    match result {
        ColumnarValue::Array(output) => {
            assert_eq!(5, output.len());
            let output_struct = output.as_struct();
            assert_eq!(1, output_struct.columns().len());
            let first_column = output_struct.column(0).as_string::<i32>();
            assert_eq!("1", first_column.value(0));
        }
        _ => unreachable!(),
    }
}
23392447}
0 commit comments