@@ -327,6 +327,16 @@ impl<'a> BinaryTypeCoercer<'a> {
327327
328328// TODO Move the rest inside of BinaryTypeCoercer
329329
330+ fn is_decimal ( data_type : & DataType ) -> bool {
331+ matches ! (
332+ data_type,
333+ DataType :: Decimal32 ( ..)
334+ | DataType :: Decimal64 ( ..)
335+ | DataType :: Decimal128 ( ..)
336+ | DataType :: Decimal256 ( ..)
337+ )
338+ }
339+
330340/// Coercion rules for mathematics operators between decimal and non-decimal types.
331341fn math_decimal_coercion (
332342 lhs_type : & DataType ,
@@ -357,6 +367,15 @@ fn math_decimal_coercion(
357367 | ( Decimal256 ( _, _) , Decimal256 ( _, _) ) => {
358368 Some ( ( lhs_type. clone ( ) , rhs_type. clone ( ) ) )
359369 }
370+ // Cross-variant decimal coercion - choose larger variant with appropriate precision/scale
371+ ( lhs, rhs)
372+ if is_decimal ( lhs)
373+ && is_decimal ( rhs)
374+ && std:: mem:: discriminant ( lhs) != std:: mem:: discriminant ( rhs) =>
375+ {
376+ let coerced_type = get_wider_decimal_type_cross_variant ( lhs_type, rhs_type) ?;
377+ Some ( ( coerced_type. clone ( ) , coerced_type) )
378+ }
360379 // Unlike with comparison we don't coerce to a decimal in the case of floating point
361380 // numbers, instead falling back to floating point arithmetic instead
362381 (
@@ -953,21 +972,92 @@ pub fn binary_numeric_coercion(
953972pub fn decimal_coercion ( lhs_type : & DataType , rhs_type : & DataType ) -> Option < DataType > {
954973 use arrow:: datatypes:: DataType :: * ;
955974
975+ // Prefer decimal data type over floating point for comparison operation
956976 match ( lhs_type, rhs_type) {
957- // Prefer decimal data type over floating point for comparison operation
958- ( Decimal128 ( _, _) , Decimal128 ( _, _) ) => {
977+ // Same decimal types
978+ ( lhs_type, rhs_type)
979+ if is_decimal ( lhs_type)
980+ && is_decimal ( rhs_type)
981+ && std:: mem:: discriminant ( lhs_type)
982+ == std:: mem:: discriminant ( rhs_type) =>
983+ {
959984 get_wider_decimal_type ( lhs_type, rhs_type)
960985 }
961- ( Decimal128 ( _, _) , _) => get_common_decimal_type ( lhs_type, rhs_type) ,
962- ( _, Decimal128 ( _, _) ) => get_common_decimal_type ( rhs_type, lhs_type) ,
963- ( Decimal256 ( _, _) , Decimal256 ( _, _) ) => {
964- get_wider_decimal_type ( lhs_type, rhs_type)
986+ // Mismatched decimal types
987+ ( lhs_type, rhs_type)
988+ if is_decimal ( lhs_type)
989+ && is_decimal ( rhs_type)
990+ && std:: mem:: discriminant ( lhs_type)
991+ != std:: mem:: discriminant ( rhs_type) =>
992+ {
993+ get_wider_decimal_type_cross_variant ( lhs_type, rhs_type)
994+ }
995+ // Decimal + non-decimal types
996+ ( Decimal32 ( _, _) | Decimal64 ( _, _) | Decimal128 ( _, _) | Decimal256 ( _, _) , _) => {
997+ get_common_decimal_type ( lhs_type, rhs_type)
998+ }
999+ ( _, Decimal32 ( _, _) | Decimal64 ( _, _) | Decimal128 ( _, _) | Decimal256 ( _, _) ) => {
1000+ get_common_decimal_type ( rhs_type, lhs_type)
9651001 }
966- ( Decimal256 ( _, _) , _) => get_common_decimal_type ( lhs_type, rhs_type) ,
967- ( _, Decimal256 ( _, _) ) => get_common_decimal_type ( rhs_type, lhs_type) ,
9681002 ( _, _) => None ,
9691003 }
9701004}
1005+ /// Handle cross-variant decimal widening by choosing the larger variant
1006+ fn get_wider_decimal_type_cross_variant (
1007+ lhs_type : & DataType ,
1008+ rhs_type : & DataType ,
1009+ ) -> Option < DataType > {
1010+ use arrow:: datatypes:: DataType :: * ;
1011+
1012+ let ( p1, s1) = match lhs_type {
1013+ Decimal32 ( p, s) => ( * p, * s) ,
1014+ Decimal64 ( p, s) => ( * p, * s) ,
1015+ Decimal128 ( p, s) => ( * p, * s) ,
1016+ Decimal256 ( p, s) => ( * p, * s) ,
1017+ _ => return None ,
1018+ } ;
1019+
1020+ let ( p2, s2) = match rhs_type {
1021+ Decimal32 ( p, s) => ( * p, * s) ,
1022+ Decimal64 ( p, s) => ( * p, * s) ,
1023+ Decimal128 ( p, s) => ( * p, * s) ,
1024+ Decimal256 ( p, s) => ( * p, * s) ,
1025+ _ => return None ,
1026+ } ;
1027+
1028+ // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
1029+ let s = s1. max ( s2) ;
1030+ let range = ( p1 as i8 - s1) . max ( p2 as i8 - s2) ;
1031+ let required_precision = ( range + s) as u8 ;
1032+
1033+ // Choose the larger variant between the two input types, while making sure we don't overflow the precision.
1034+ match ( lhs_type, rhs_type) {
1035+ ( Decimal32 ( _, _) , Decimal64 ( _, _) ) | ( Decimal64 ( _, _) , Decimal32 ( _, _) )
1036+ if required_precision <= DECIMAL64_MAX_PRECISION =>
1037+ {
1038+ Some ( Decimal64 ( required_precision, s) )
1039+ }
1040+ ( Decimal32 ( _, _) , Decimal128 ( _, _) )
1041+ | ( Decimal128 ( _, _) , Decimal32 ( _, _) )
1042+ | ( Decimal64 ( _, _) , Decimal128 ( _, _) )
1043+ | ( Decimal128 ( _, _) , Decimal64 ( _, _) )
1044+ if required_precision <= DECIMAL128_MAX_PRECISION =>
1045+ {
1046+ Some ( Decimal128 ( required_precision, s) )
1047+ }
1048+ ( Decimal32 ( _, _) , Decimal256 ( _, _) )
1049+ | ( Decimal256 ( _, _) , Decimal32 ( _, _) )
1050+ | ( Decimal64 ( _, _) , Decimal256 ( _, _) )
1051+ | ( Decimal256 ( _, _) , Decimal64 ( _, _) )
1052+ | ( Decimal128 ( _, _) , Decimal256 ( _, _) )
1053+ | ( Decimal256 ( _, _) , Decimal128 ( _, _) )
1054+ if required_precision <= DECIMAL256_MAX_PRECISION =>
1055+ {
1056+ Some ( Decimal256 ( required_precision, s) )
1057+ }
1058+ _ => None ,
1059+ }
1060+ }
9711061
9721062/// Coerce `lhs_type` and `rhs_type` to a common type.
9731063fn get_common_decimal_type (
@@ -976,7 +1066,15 @@ fn get_common_decimal_type(
9761066) -> Option < DataType > {
9771067 use arrow:: datatypes:: DataType :: * ;
9781068 match decimal_type {
979- Decimal32 ( _, _) | Decimal64 ( _, _) | Decimal128 ( _, _) => {
1069+ Decimal32 ( _, _) => {
1070+ let other_decimal_type = coerce_numeric_type_to_decimal32 ( other_type) ?;
1071+ get_wider_decimal_type ( decimal_type, & other_decimal_type)
1072+ }
1073+ Decimal64 ( _, _) => {
1074+ let other_decimal_type = coerce_numeric_type_to_decimal64 ( other_type) ?;
1075+ get_wider_decimal_type ( decimal_type, & other_decimal_type)
1076+ }
1077+ Decimal128 ( _, _) => {
9801078 let other_decimal_type = coerce_numeric_type_to_decimal128 ( other_type) ?;
9811079 get_wider_decimal_type ( decimal_type, & other_decimal_type)
9821080 }
@@ -988,7 +1086,7 @@ fn get_common_decimal_type(
9881086 }
9891087}
9901088
991- /// Returns a `DataType::Decimal128` that can store any value from either
1089+ /// Returns a decimal [ `DataType`] variant that can store any value from either
9921090/// `lhs_decimal_type` and `rhs_decimal_type`
9931091///
9941092/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`.
@@ -1209,14 +1307,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
12091307}
12101308
12111309fn create_decimal32_type ( precision : u8 , scale : i8 ) -> DataType {
1212- DataType :: Decimal128 (
1310+ DataType :: Decimal32 (
12131311 DECIMAL32_MAX_PRECISION . min ( precision) ,
12141312 DECIMAL32_MAX_SCALE . min ( scale) ,
12151313 )
12161314}
12171315
12181316fn create_decimal64_type ( precision : u8 , scale : i8 ) -> DataType {
1219- DataType :: Decimal128 (
1317+ DataType :: Decimal64 (
12201318 DECIMAL64_MAX_PRECISION . min ( precision) ,
12211319 DECIMAL64_MAX_SCALE . min ( scale) ,
12221320 )
0 commit comments