Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,7 @@ pub fn comparison_coercion_numeric(
return Some(lhs_type.clone());
}
binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_comparison_coercion_numeric(lhs_type, rhs_type, true))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
.or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type))
Expand Down Expand Up @@ -1146,38 +1147,75 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) ->
}
}

/// Coercion rules for Dictionaries: the type that both lhs and rhs
/// Generic coercion rules for Dictionaries: the type that both lhs and rhs
/// can be cast to for the purpose of a computation.
///
/// Not all operators support dictionaries, if `preserve_dictionaries` is true
/// dictionaries will be preserved if possible
fn dictionary_comparison_coercion(
/// dictionaries will be preserved if possible.
///
/// The `coerce_fn` parameter determines which comparison coercion function to use
/// for comparing the dictionary value types.
///
/// Returns `None` when neither `lhs_type` nor `rhs_type` is a dictionary.
fn dictionary_comparison_coercion_generic(
lhs_type: &DataType,
rhs_type: &DataType,
preserve_dictionaries: bool,
coerce_fn: fn(&DataType, &DataType) -> Option<DataType>,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
// Both sides are dictionaries: coerce their value types; the index types
// are ignored.
(
Dictionary(_lhs_index_type, lhs_value_type),
Dictionary(_rhs_index_type, rhs_value_type),
) => comparison_coercion(lhs_value_type, rhs_value_type),
) => coerce_fn(lhs_value_type, rhs_value_type),
// Exactly the dictionary's value type on the other side: keep the
// dictionary type itself, but only when `preserve_dictionaries` is set.
(d @ Dictionary(_, value_type), other_type)
| (other_type, d @ Dictionary(_, value_type))
if preserve_dictionaries && value_type.as_ref() == other_type =>
{
Some(d.clone())
}
// Otherwise unwrap the dictionary side to its value type and coerce that
// against the other side.
(Dictionary(_index_type, value_type), _) => {
comparison_coercion(value_type, rhs_type)
}
(_, Dictionary(_index_type, value_type)) => {
comparison_coercion(lhs_type, value_type)
}
(Dictionary(_index_type, value_type), _) => coerce_fn(value_type, rhs_type),
(_, Dictionary(_index_type, value_type)) => coerce_fn(lhs_type, value_type),
// Neither side is a dictionary: nothing for this rule to do.
_ => None,
}
}

/// Dictionary coercion for ordinary comparisons: the type that both
/// `lhs_type` and `rhs_type` can be cast to for the purpose of a computation.
///
/// This delegates to `dictionary_comparison_coercion_generic`, passing
/// `comparison_coercion` as the function used on the dictionary value types.
///
/// Not all operators support dictionaries; when `preserve_dictionaries` is
/// true, dictionaries will be preserved if possible.
fn dictionary_comparison_coercion(
    lhs_type: &DataType,
    rhs_type: &DataType,
    preserve_dictionaries: bool,
) -> Option<DataType> {
    // Name the value-type coercion up front so the delegation reads clearly.
    let coerce_values: fn(&DataType, &DataType) -> Option<DataType> =
        comparison_coercion;
    dictionary_comparison_coercion_generic(
        lhs_type,
        rhs_type,
        preserve_dictionaries,
        coerce_values,
    )
}

/// Numeric-preferring dictionary coercion: behaves like
/// [`dictionary_comparison_coercion`], except that the dictionary value types
/// are reconciled with [`comparison_coercion_numeric`], which prefers numeric
/// types over strings when both are present.
///
/// [`comparison_coercion_numeric`] itself calls this function, so that its
/// numeric-preferring semantics stay consistent when dictionary types are
/// involved.
fn dictionary_comparison_coercion_numeric(
    lhs_type: &DataType,
    rhs_type: &DataType,
    preserve_dictionaries: bool,
) -> Option<DataType> {
    // Name the value-type coercion up front so the delegation reads clearly.
    let coerce_values: fn(&DataType, &DataType) -> Option<DataType> =
        comparison_coercion_numeric;
    dictionary_comparison_coercion_generic(
        lhs_type,
        rhs_type,
        preserve_dictionaries,
        coerce_values,
    )
}

/// Coercion rules for string concat.
/// This is a union of string coercion rules and specified rules:
/// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8)
Expand Down
36 changes: 36 additions & 0 deletions datafusion/sqllogictest/test_files/nullif.slt
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,39 @@ query T
SELECT NULLIF(arrow_cast('a', 'Utf8View'), null);
----
a

# Test with dictionary-encoded strings
# This tests the fix for: "Dictionary(UInt32, Utf8) and Utf8 is not comparable"
statement ok
CREATE TABLE dict_test_base(
col1 TEXT,
col2 TEXT
) as VALUES
('foo', 'bar'),
('bar', 'bar'),
('baz', 'bar')
;

# Dictionary cast with string literal
query T rowsort
SELECT NULLIF(arrow_cast(col1, 'Dictionary(Int32, Utf8)'), 'bar') FROM dict_test_base;
----
NULL
baz
foo

# Plain string column as the first argument, dictionary-cast column as the second
query T rowsort
SELECT NULLIF(col2, arrow_cast(col1, 'Dictionary(Int32, Utf8)')) FROM dict_test_base;
----
NULL
bar
bar

# Both arguments dictionary-encoded
query T rowsort
SELECT NULLIF(arrow_cast(col1, 'Dictionary(Int32, Utf8)'), arrow_cast('bar', 'Dictionary(Int32, Utf8)')) FROM dict_test_base;
----
NULL
baz
foo
Loading