1717
1818//! Coercion rules for matching argument types for binary operators
1919
20+ use std:: collections:: HashSet ;
2021use std:: sync:: Arc ;
2122
2223use crate :: Operator ;
@@ -289,13 +290,207 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataT
289290 }
290291}
291292
293+ #[ derive( Debug , PartialEq , Eq , Hash , Clone ) ]
294+ enum TypeCategory {
295+ Array ,
296+ Boolean ,
297+ Numeric ,
298+ // String, well-defined type, but are considered as unknown type.
299+ DateTime ,
300+ Composite ,
301+ Unknown ,
302+ NotSupported ,
303+ }
304+
305+ impl From < & DataType > for TypeCategory {
306+ fn from ( data_type : & DataType ) -> Self {
307+ match data_type {
308+ // Dict is a special type in arrow, we check the value type
309+ DataType :: Dictionary ( _, v) => {
310+ let v = v. as_ref ( ) ;
311+ TypeCategory :: from ( v)
312+ }
313+ _ => {
314+ if data_type. is_numeric ( ) {
315+ return TypeCategory :: Numeric ;
316+ }
317+
318+ if matches ! ( data_type, DataType :: Boolean ) {
319+ return TypeCategory :: Boolean ;
320+ }
321+
322+ if matches ! (
323+ data_type,
324+ DataType :: List ( _)
325+ | DataType :: FixedSizeList ( _, _)
326+ | DataType :: LargeList ( _)
327+ ) {
328+ return TypeCategory :: Array ;
329+ }
330+
331+ // String literal is possible to cast to many other types like numeric or datetime,
332+ // therefore, it is categorized as a unknown type
333+ if matches ! (
334+ data_type,
335+ DataType :: Utf8 | DataType :: LargeUtf8 | DataType :: Null
336+ ) {
337+ return TypeCategory :: Unknown ;
338+ }
339+
340+ if matches ! (
341+ data_type,
342+ DataType :: Date32
343+ | DataType :: Date64
344+ | DataType :: Time32 ( _)
345+ | DataType :: Time64 ( _)
346+ | DataType :: Timestamp ( _, _)
347+ | DataType :: Interval ( _)
348+ | DataType :: Duration ( _)
349+ ) {
350+ return TypeCategory :: DateTime ;
351+ }
352+
353+ if matches ! (
354+ data_type,
355+ DataType :: Map ( _, _) | DataType :: Struct ( _) | DataType :: Union ( _, _)
356+ ) {
357+ return TypeCategory :: Composite ;
358+ }
359+
360+ TypeCategory :: NotSupported
361+ }
362+ }
363+ }
364+ }
365+
366+ /// Coerce dissimilar data types to a single data type.
367+ /// UNION, INTERSECT, EXCEPT, CASE, ARRAY, VALUES, and the GREATEST and LEAST functions are
368+ /// examples that has the similar resolution rules.
369+ /// See <https://www.postgresql.org/docs/current/typeconv-union-case.html> for more information.
370+ /// The rules in the document provide a clue, but adhering strictly to them doesn't precisely
371+ /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules
372+ /// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted
373+ /// decimal percision and scale when coercing decimal types.
374+ pub fn type_union_resolution ( data_types : & [ DataType ] ) -> Option < DataType > {
375+ if data_types. is_empty ( ) {
376+ return None ;
377+ }
378+
379+ // if all the data_types is the same return first one
380+ if data_types. iter ( ) . all ( |t| t == & data_types[ 0 ] ) {
381+ return Some ( data_types[ 0 ] . clone ( ) ) ;
382+ }
383+
384+ // if all the data_types are null, return string
385+ if data_types. iter ( ) . all ( |t| t == & DataType :: Null ) {
386+ return Some ( DataType :: Utf8 ) ;
387+ }
388+
389+ // Ignore Nulls, if any data_type category is not the same, return None
390+ let data_types_category: Vec < TypeCategory > = data_types
391+ . iter ( )
392+ . filter ( |& t| t != & DataType :: Null )
393+ . map ( |t| t. into ( ) )
394+ . collect ( ) ;
395+
396+ if data_types_category
397+ . iter ( )
398+ . any ( |t| t == & TypeCategory :: NotSupported )
399+ {
400+ return None ;
401+ }
402+
403+ // check if there is only one category excluding Unknown
404+ let categories: HashSet < TypeCategory > = HashSet :: from_iter (
405+ data_types_category
406+ . iter ( )
407+ . filter ( |& c| c != & TypeCategory :: Unknown )
408+ . cloned ( ) ,
409+ ) ;
410+ if categories. len ( ) > 1 {
411+ return None ;
412+ }
413+
414+ // Ignore Nulls
415+ let mut candidate_type: Option < DataType > = None ;
416+ for data_type in data_types. iter ( ) {
417+ if data_type == & DataType :: Null {
418+ continue ;
419+ }
420+ if let Some ( ref candidate_t) = candidate_type {
421+ // Find candidate type that all the data types can be coerced to
422+ // Follows the behavior of Postgres and DuckDB
423+ // Coerced type may be different from the candidate and current data type
424+ // For example,
425+ // i64 and decimal(7, 2) are expect to get coerced type decimal(22, 2)
426+ // numeric string ('1') and numeric (2) are expect to get coerced type numeric (1, 2)
427+ if let Some ( t) = type_union_resolution_coercion ( data_type, candidate_t) {
428+ candidate_type = Some ( t) ;
429+ } else {
430+ return None ;
431+ }
432+ } else {
433+ candidate_type = Some ( data_type. clone ( ) ) ;
434+ }
435+ }
436+
437+ candidate_type
438+ }
439+
440+ /// Coerce `lhs_type` and `rhs_type` to a common type for [type_union_resolution]
441+ /// See [type_union_resolution] for more information.
442+ fn type_union_resolution_coercion (
443+ lhs_type : & DataType ,
444+ rhs_type : & DataType ,
445+ ) -> Option < DataType > {
446+ if lhs_type == rhs_type {
447+ return Some ( lhs_type. clone ( ) ) ;
448+ }
449+
450+ match ( lhs_type, rhs_type) {
451+ (
452+ DataType :: Dictionary ( lhs_index_type, lhs_value_type) ,
453+ DataType :: Dictionary ( rhs_index_type, rhs_value_type) ,
454+ ) => {
455+ let new_index_type =
456+ type_union_resolution_coercion ( lhs_index_type, rhs_index_type) ;
457+ let new_value_type =
458+ type_union_resolution_coercion ( lhs_value_type, rhs_value_type) ;
459+ if let ( Some ( new_index_type) , Some ( new_value_type) ) =
460+ ( new_index_type, new_value_type)
461+ {
462+ Some ( DataType :: Dictionary (
463+ Box :: new ( new_index_type) ,
464+ Box :: new ( new_value_type) ,
465+ ) )
466+ } else {
467+ None
468+ }
469+ }
470+ ( DataType :: Dictionary ( index_type, value_type) , other_type)
471+ | ( other_type, DataType :: Dictionary ( index_type, value_type) ) => {
472+ let new_value_type = type_union_resolution_coercion ( value_type, other_type) ;
473+ new_value_type. map ( |t| DataType :: Dictionary ( index_type. clone ( ) , Box :: new ( t) ) )
474+ }
475+ _ => {
476+ // numeric coercion is the same as comparison coercion, both find the narrowest type
477+ // that can accommodate both types
478+ binary_numeric_coercion ( lhs_type, rhs_type)
479+ . or_else ( || string_coercion ( lhs_type, rhs_type) )
480+ . or_else ( || numeric_string_coercion ( lhs_type, rhs_type) )
481+ }
482+ }
483+ }
484+
292485/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
486+ /// Unlike `coerced_from`, usually the coerced type is for comparison only.
487+ /// For example, compare with Dictionary and Dictionary, only value type is what we care about
293488pub fn comparison_coercion ( lhs_type : & DataType , rhs_type : & DataType ) -> Option < DataType > {
294489 if lhs_type == rhs_type {
295490 // same type => equality is possible
296491 return Some ( lhs_type. clone ( ) ) ;
297492 }
298- comparison_binary_numeric_coercion ( lhs_type, rhs_type)
493+ binary_numeric_coercion ( lhs_type, rhs_type)
299494 . or_else ( || dictionary_coercion ( lhs_type, rhs_type, true ) )
300495 . or_else ( || temporal_coercion ( lhs_type, rhs_type) )
301496 . or_else ( || string_coercion ( lhs_type, rhs_type) )
@@ -312,7 +507,7 @@ pub fn values_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataT
312507 // same type => equality is possible
313508 return Some ( lhs_type. clone ( ) ) ;
314509 }
315- comparison_binary_numeric_coercion ( lhs_type, rhs_type)
510+ binary_numeric_coercion ( lhs_type, rhs_type)
316511 . or_else ( || temporal_coercion ( lhs_type, rhs_type) )
317512 . or_else ( || string_coercion ( lhs_type, rhs_type) )
318513 . or_else ( || binary_coercion ( lhs_type, rhs_type) )
@@ -372,9 +567,8 @@ fn string_temporal_coercion(
372567 match_rule ( lhs_type, rhs_type) . or_else ( || match_rule ( rhs_type, lhs_type) )
373568}
374569
375- /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
376- /// where one both are numeric
377- pub ( crate ) fn comparison_binary_numeric_coercion (
570+ /// Coerce `lhs_type` and `rhs_type` to a common type where both are numeric
571+ pub ( crate ) fn binary_numeric_coercion (
378572 lhs_type : & DataType ,
379573 rhs_type : & DataType ,
380574) -> Option < DataType > {
@@ -388,27 +582,25 @@ pub(crate) fn comparison_binary_numeric_coercion(
388582 return Some ( lhs_type. clone ( ) ) ;
389583 }
390584
585+ if let Some ( t) = decimal_coercion ( lhs_type, rhs_type) {
586+ return Some ( t) ;
587+ }
588+
391589 // these are ordered from most informative to least informative so
392590 // that the coercion does not lose information via truncation
393591 match ( lhs_type, rhs_type) {
394- // Prefer decimal data type over floating point for comparison operation
395- ( Decimal128 ( _, _) , Decimal128 ( _, _) ) => {
396- get_wider_decimal_type ( lhs_type, rhs_type)
397- }
398- ( Decimal128 ( _, _) , _) => get_comparison_common_decimal_type ( lhs_type, rhs_type) ,
399- ( _, Decimal128 ( _, _) ) => get_comparison_common_decimal_type ( rhs_type, lhs_type) ,
400- ( Decimal256 ( _, _) , Decimal256 ( _, _) ) => {
401- get_wider_decimal_type ( lhs_type, rhs_type)
402- }
403- ( Decimal256 ( _, _) , _) => get_comparison_common_decimal_type ( lhs_type, rhs_type) ,
404- ( _, Decimal256 ( _, _) ) => get_comparison_common_decimal_type ( rhs_type, lhs_type) ,
405592 ( Float64 , _) | ( _, Float64 ) => Some ( Float64 ) ,
406593 ( _, Float32 ) | ( Float32 , _) => Some ( Float32 ) ,
407594 // The following match arms encode the following logic: Given the two
408595 // integral types, we choose the narrowest possible integral type that
409596 // accommodates all values of both types. Note that some information
410597 // loss is inevitable when we have a signed type and a `UInt64`, in
411598 // which case we use `Int64`;i.e. the widest signed integral type.
599+
600+ // TODO: For i64 and u64, we can use decimal or float64
601+ // Postgres has no unsigned type :(
602+ // DuckDB v.0.10.0 has double (double precision floating-point number (8 bytes))
603+ // for largest signed (signed sixteen-byte integer) and unsigned integer (unsigned sixteen-byte integer)
412604 ( Int64 , _)
413605 | ( _, Int64 )
414606 | ( UInt64 , Int8 )
@@ -439,9 +631,28 @@ pub(crate) fn comparison_binary_numeric_coercion(
439631 }
440632}
441633
442- /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of
443- /// a comparison operation where one is a decimal
444- fn get_comparison_common_decimal_type (
634+ /// Decimal coercion rules.
635+ pub fn decimal_coercion ( lhs_type : & DataType , rhs_type : & DataType ) -> Option < DataType > {
636+ use arrow:: datatypes:: DataType :: * ;
637+
638+ match ( lhs_type, rhs_type) {
639+ // Prefer decimal data type over floating point for comparison operation
640+ ( Decimal128 ( _, _) , Decimal128 ( _, _) ) => {
641+ get_wider_decimal_type ( lhs_type, rhs_type)
642+ }
643+ ( Decimal128 ( _, _) , _) => get_common_decimal_type ( lhs_type, rhs_type) ,
644+ ( _, Decimal128 ( _, _) ) => get_common_decimal_type ( rhs_type, lhs_type) ,
645+ ( Decimal256 ( _, _) , Decimal256 ( _, _) ) => {
646+ get_wider_decimal_type ( lhs_type, rhs_type)
647+ }
648+ ( Decimal256 ( _, _) , _) => get_common_decimal_type ( lhs_type, rhs_type) ,
649+ ( _, Decimal256 ( _, _) ) => get_common_decimal_type ( rhs_type, lhs_type) ,
650+ ( _, _) => None ,
651+ }
652+ }
653+
654+ /// Coerce `lhs_type` and `rhs_type` to a common type.
655+ fn get_common_decimal_type (
445656 decimal_type : & DataType ,
446657 other_type : & DataType ,
447658) -> Option < DataType > {
@@ -725,6 +936,18 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
725936 }
726937}
727938
939+ fn numeric_string_coercion ( lhs_type : & DataType , rhs_type : & DataType ) -> Option < DataType > {
940+ use arrow:: datatypes:: DataType :: * ;
941+ match ( lhs_type, rhs_type) {
942+ ( Utf8 | LargeUtf8 , other_type) | ( other_type, Utf8 | LargeUtf8 )
943+ if other_type. is_numeric ( ) =>
944+ {
945+ Some ( other_type. clone ( ) )
946+ }
947+ _ => None ,
948+ }
949+ }
950+
728951/// Coercion rules for list types.
729952fn list_coercion ( lhs_type : & DataType , rhs_type : & DataType ) -> Option < DataType > {
730953 use arrow:: datatypes:: DataType :: * ;
0 commit comments