Skip to content

Commit 0836500

Browse files
authored
Fix Coalesce casting logic to follow what Postgres and DuckDB do. Introduce a signature that does non-comparison coercion (#10268)
* remove casting for coalesce Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add more test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add more test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * crate only visibility Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * polish comment Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * improve test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * backup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * introduce new signautre for coalesce Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * ignore err msg Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fmt Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fix doc Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add more test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * switch to type_resolution coercion Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fix i64 and u64 case Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add more tests Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add null case Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fmt Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fix Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * rename to type_union_resolution Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add comment Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fix test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add comment Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * rm test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup since rebase Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add more test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add more test Signed-off-by: 
jayzhan211 <jayzhan211@gmail.com> * fix msg Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fmt Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * rm pure_string_coercion Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * rm duplicate Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * change type in select.slt Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fix slt Signed-off-by: jayzhan211 <jayzhan211@gmail.com> --------- Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
1 parent 26b44f4 commit 0836500

File tree

5 files changed

+378
-91
lines changed

5 files changed

+378
-91
lines changed

datafusion/expr/src/type_coercion/binary.rs

Lines changed: 242 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! Coercion rules for matching argument types for binary operators
1919
20+
use std::collections::HashSet;
2021
use std::sync::Arc;
2122

2223
use crate::Operator;
@@ -289,13 +290,207 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataT
289290
}
290291
}
291292

293+
/// Coarse grouping of data types, used by [type_union_resolution] to decide whether
/// a set of types may be resolved to one common type.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum TypeCategory {
    /// `List`, `FixedSizeList` and `LargeList`.
    Array,
    /// `Boolean`.
    Boolean,
    /// Every type for which `DataType::is_numeric` returns true.
    Numeric,
    // There is intentionally no `String` category: strings are well-defined types,
    // but they are classified as `Unknown`.
    /// Dates, times, timestamps, intervals and durations.
    DateTime,
    /// `Map`, `Struct` and `Union`.
    Composite,
    /// `Utf8`, `LargeUtf8` and `Null`.
    Unknown,
    /// Any type that does not fall in one of the categories above.
    NotSupported,
}
304+
305+
impl From<&DataType> for TypeCategory {
306+
fn from(data_type: &DataType) -> Self {
307+
match data_type {
308+
// Dict is a special type in arrow, we check the value type
309+
DataType::Dictionary(_, v) => {
310+
let v = v.as_ref();
311+
TypeCategory::from(v)
312+
}
313+
_ => {
314+
if data_type.is_numeric() {
315+
return TypeCategory::Numeric;
316+
}
317+
318+
if matches!(data_type, DataType::Boolean) {
319+
return TypeCategory::Boolean;
320+
}
321+
322+
if matches!(
323+
data_type,
324+
DataType::List(_)
325+
| DataType::FixedSizeList(_, _)
326+
| DataType::LargeList(_)
327+
) {
328+
return TypeCategory::Array;
329+
}
330+
331+
// String literal is possible to cast to many other types like numeric or datetime,
332+
// therefore, it is categorized as a unknown type
333+
if matches!(
334+
data_type,
335+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Null
336+
) {
337+
return TypeCategory::Unknown;
338+
}
339+
340+
if matches!(
341+
data_type,
342+
DataType::Date32
343+
| DataType::Date64
344+
| DataType::Time32(_)
345+
| DataType::Time64(_)
346+
| DataType::Timestamp(_, _)
347+
| DataType::Interval(_)
348+
| DataType::Duration(_)
349+
) {
350+
return TypeCategory::DateTime;
351+
}
352+
353+
if matches!(
354+
data_type,
355+
DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _)
356+
) {
357+
return TypeCategory::Composite;
358+
}
359+
360+
TypeCategory::NotSupported
361+
}
362+
}
363+
}
364+
}
365+
366+
/// Coerce dissimilar data types to a single data type.
367+
/// UNION, INTERSECT, EXCEPT, CASE, ARRAY, VALUES, and the GREATEST and LEAST functions are
368+
/// examples that has the similar resolution rules.
369+
/// See <https://www.postgresql.org/docs/current/typeconv-union-case.html> for more information.
370+
/// The rules in the document provide a clue, but adhering strictly to them doesn't precisely
371+
/// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules
372+
/// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted
373+
/// decimal percision and scale when coercing decimal types.
374+
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
375+
if data_types.is_empty() {
376+
return None;
377+
}
378+
379+
// if all the data_types is the same return first one
380+
if data_types.iter().all(|t| t == &data_types[0]) {
381+
return Some(data_types[0].clone());
382+
}
383+
384+
// if all the data_types are null, return string
385+
if data_types.iter().all(|t| t == &DataType::Null) {
386+
return Some(DataType::Utf8);
387+
}
388+
389+
// Ignore Nulls, if any data_type category is not the same, return None
390+
let data_types_category: Vec<TypeCategory> = data_types
391+
.iter()
392+
.filter(|&t| t != &DataType::Null)
393+
.map(|t| t.into())
394+
.collect();
395+
396+
if data_types_category
397+
.iter()
398+
.any(|t| t == &TypeCategory::NotSupported)
399+
{
400+
return None;
401+
}
402+
403+
// check if there is only one category excluding Unknown
404+
let categories: HashSet<TypeCategory> = HashSet::from_iter(
405+
data_types_category
406+
.iter()
407+
.filter(|&c| c != &TypeCategory::Unknown)
408+
.cloned(),
409+
);
410+
if categories.len() > 1 {
411+
return None;
412+
}
413+
414+
// Ignore Nulls
415+
let mut candidate_type: Option<DataType> = None;
416+
for data_type in data_types.iter() {
417+
if data_type == &DataType::Null {
418+
continue;
419+
}
420+
if let Some(ref candidate_t) = candidate_type {
421+
// Find candidate type that all the data types can be coerced to
422+
// Follows the behavior of Postgres and DuckDB
423+
// Coerced type may be different from the candidate and current data type
424+
// For example,
425+
// i64 and decimal(7, 2) are expect to get coerced type decimal(22, 2)
426+
// numeric string ('1') and numeric (2) are expect to get coerced type numeric (1, 2)
427+
if let Some(t) = type_union_resolution_coercion(data_type, candidate_t) {
428+
candidate_type = Some(t);
429+
} else {
430+
return None;
431+
}
432+
} else {
433+
candidate_type = Some(data_type.clone());
434+
}
435+
}
436+
437+
candidate_type
438+
}
439+
440+
/// Coerce `lhs_type` and `rhs_type` to a common type for [type_union_resolution]
441+
/// See [type_union_resolution] for more information.
442+
fn type_union_resolution_coercion(
443+
lhs_type: &DataType,
444+
rhs_type: &DataType,
445+
) -> Option<DataType> {
446+
if lhs_type == rhs_type {
447+
return Some(lhs_type.clone());
448+
}
449+
450+
match (lhs_type, rhs_type) {
451+
(
452+
DataType::Dictionary(lhs_index_type, lhs_value_type),
453+
DataType::Dictionary(rhs_index_type, rhs_value_type),
454+
) => {
455+
let new_index_type =
456+
type_union_resolution_coercion(lhs_index_type, rhs_index_type);
457+
let new_value_type =
458+
type_union_resolution_coercion(lhs_value_type, rhs_value_type);
459+
if let (Some(new_index_type), Some(new_value_type)) =
460+
(new_index_type, new_value_type)
461+
{
462+
Some(DataType::Dictionary(
463+
Box::new(new_index_type),
464+
Box::new(new_value_type),
465+
))
466+
} else {
467+
None
468+
}
469+
}
470+
(DataType::Dictionary(index_type, value_type), other_type)
471+
| (other_type, DataType::Dictionary(index_type, value_type)) => {
472+
let new_value_type = type_union_resolution_coercion(value_type, other_type);
473+
new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t)))
474+
}
475+
_ => {
476+
// numeric coercion is the same as comparison coercion, both find the narrowest type
477+
// that can accommodate both types
478+
binary_numeric_coercion(lhs_type, rhs_type)
479+
.or_else(|| string_coercion(lhs_type, rhs_type))
480+
.or_else(|| numeric_string_coercion(lhs_type, rhs_type))
481+
}
482+
}
483+
}
484+
292485
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
486+
/// Unlike `coerced_from`, usually the coerced type is for comparison only.
487+
/// For example, when comparing a Dictionary with a Dictionary, only the value type matters
293488
pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
294489
if lhs_type == rhs_type {
295490
// same type => equality is possible
296491
return Some(lhs_type.clone());
297492
}
298-
comparison_binary_numeric_coercion(lhs_type, rhs_type)
493+
binary_numeric_coercion(lhs_type, rhs_type)
299494
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
300495
.or_else(|| temporal_coercion(lhs_type, rhs_type))
301496
.or_else(|| string_coercion(lhs_type, rhs_type))
@@ -312,7 +507,7 @@ pub fn values_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataT
312507
// same type => equality is possible
313508
return Some(lhs_type.clone());
314509
}
315-
comparison_binary_numeric_coercion(lhs_type, rhs_type)
510+
binary_numeric_coercion(lhs_type, rhs_type)
316511
.or_else(|| temporal_coercion(lhs_type, rhs_type))
317512
.or_else(|| string_coercion(lhs_type, rhs_type))
318513
.or_else(|| binary_coercion(lhs_type, rhs_type))
@@ -372,9 +567,8 @@ fn string_temporal_coercion(
372567
match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type))
373568
}
374569

375-
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
376-
/// where one both are numeric
377-
pub(crate) fn comparison_binary_numeric_coercion(
570+
/// Coerce `lhs_type` and `rhs_type` to a common type where both are numeric
571+
pub(crate) fn binary_numeric_coercion(
378572
lhs_type: &DataType,
379573
rhs_type: &DataType,
380574
) -> Option<DataType> {
@@ -388,27 +582,25 @@ pub(crate) fn comparison_binary_numeric_coercion(
388582
return Some(lhs_type.clone());
389583
}
390584

585+
if let Some(t) = decimal_coercion(lhs_type, rhs_type) {
586+
return Some(t);
587+
}
588+
391589
// these are ordered from most informative to least informative so
392590
// that the coercion does not lose information via truncation
393591
match (lhs_type, rhs_type) {
394-
// Prefer decimal data type over floating point for comparison operation
395-
(Decimal128(_, _), Decimal128(_, _)) => {
396-
get_wider_decimal_type(lhs_type, rhs_type)
397-
}
398-
(Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
399-
(_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
400-
(Decimal256(_, _), Decimal256(_, _)) => {
401-
get_wider_decimal_type(lhs_type, rhs_type)
402-
}
403-
(Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
404-
(_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
405592
(Float64, _) | (_, Float64) => Some(Float64),
406593
(_, Float32) | (Float32, _) => Some(Float32),
407594
// The following match arms encode the following logic: Given the two
408595
// integral types, we choose the narrowest possible integral type that
409596
// accommodates all values of both types. Note that some information
410597
// loss is inevitable when we have a signed type and a `UInt64`, in
411598
// which case we use `Int64`;i.e. the widest signed integral type.
599+
600+
// TODO: For i64 and u64, we can use decimal or float64
601+
// Postgres has no unsigned type :(
602+
// DuckDB v.0.10.0 has double (double precision floating-point number (8 bytes))
603+
// for largest signed (signed sixteen-byte integer) and unsigned integer (unsigned sixteen-byte integer)
412604
(Int64, _)
413605
| (_, Int64)
414606
| (UInt64, Int8)
@@ -439,9 +631,28 @@ pub(crate) fn comparison_binary_numeric_coercion(
439631
}
440632
}
441633

442-
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of
443-
/// a comparison operation where one is a decimal
444-
fn get_comparison_common_decimal_type(
634+
/// Decimal coercion rules.
635+
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
636+
use arrow::datatypes::DataType::*;
637+
638+
match (lhs_type, rhs_type) {
639+
// Prefer decimal data type over floating point for comparison operation
640+
(Decimal128(_, _), Decimal128(_, _)) => {
641+
get_wider_decimal_type(lhs_type, rhs_type)
642+
}
643+
(Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
644+
(_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
645+
(Decimal256(_, _), Decimal256(_, _)) => {
646+
get_wider_decimal_type(lhs_type, rhs_type)
647+
}
648+
(Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
649+
(_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
650+
(_, _) => None,
651+
}
652+
}
653+
654+
/// Coerce `lhs_type` and `rhs_type` to a common type.
655+
fn get_common_decimal_type(
445656
decimal_type: &DataType,
446657
other_type: &DataType,
447658
) -> Option<DataType> {
@@ -725,6 +936,18 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
725936
}
726937
}
727938

939+
fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
940+
use arrow::datatypes::DataType::*;
941+
match (lhs_type, rhs_type) {
942+
(Utf8 | LargeUtf8, other_type) | (other_type, Utf8 | LargeUtf8)
943+
if other_type.is_numeric() =>
944+
{
945+
Some(other_type.clone())
946+
}
947+
_ => None,
948+
}
949+
}
950+
728951
/// Coercion rules for list types.
729952
fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
730953
use arrow::datatypes::DataType::*;

0 commit comments

Comments
 (0)