Skip to content

Commit 5fe5bf3

Browse files
committed
Saner handling of nulls inside arrays
1 parent 6e422e0 commit 5fe5bf3

File tree

12 files changed

+481
-425
lines changed

12 files changed

+481
-425
lines changed

datafusion/expr-common/src/signature.rs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,13 +843,15 @@ impl Signature {
843843
volatility,
844844
}
845845
}
846+
846847
/// Any one of a list of [TypeSignature]s.
847848
pub fn one_of(type_signatures: Vec<TypeSignature>, volatility: Volatility) -> Self {
848849
Signature {
849850
type_signature: TypeSignature::OneOf(type_signatures),
850851
volatility,
851852
}
852853
}
854+
853855
/// Specialized Signature for ArrayAppend and similar functions
854856
pub fn array_and_element(volatility: Volatility) -> Self {
855857
Signature {
@@ -865,6 +867,39 @@ impl Signature {
865867
volatility,
866868
}
867869
}
870+
871+
/// Specialized Signature for ArrayPrepend and similar functions
872+
pub fn element_and_array(volatility: Volatility) -> Self {
873+
Signature {
874+
type_signature: TypeSignature::ArraySignature(
875+
ArrayFunctionSignature::Array {
876+
arguments: vec![
877+
ArrayFunctionArgument::Element,
878+
ArrayFunctionArgument::Array,
879+
],
880+
array_coercion: Some(ListCoercion::FixedSizedListToList),
881+
},
882+
),
883+
volatility,
884+
}
885+
}
886+
887+
/// Specialized Signature for ArrayUnion and similar functions
888+
pub fn array_and_array(volatility: Volatility) -> Self {
889+
Signature {
890+
type_signature: TypeSignature::ArraySignature(
891+
ArrayFunctionSignature::Array {
892+
arguments: vec![
893+
ArrayFunctionArgument::Array,
894+
ArrayFunctionArgument::Array,
895+
],
896+
array_coercion: Some(ListCoercion::FixedSizedListToList),
897+
},
898+
),
899+
volatility,
900+
}
901+
}
902+
868903
/// Specialized Signature for Array functions with an optional index
869904
pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self {
870905
Signature {
@@ -898,7 +933,7 @@ impl Signature {
898933
ArrayFunctionArgument::Array,
899934
ArrayFunctionArgument::Index,
900935
],
901-
array_coercion: None,
936+
array_coercion: Some(ListCoercion::FixedSizedListToList),
902937
},
903938
),
904939
volatility,
@@ -910,7 +945,7 @@ impl Signature {
910945
type_signature: TypeSignature::ArraySignature(
911946
ArrayFunctionSignature::Array {
912947
arguments: vec![ArrayFunctionArgument::Array],
913-
array_coercion: None,
948+
array_coercion: Some(ListCoercion::FixedSizedListToList),
914949
},
915950
),
916951
volatility,

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 66 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,21 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use super::binary::{binary_numeric_coercion, comparison_coercion};
18+
use super::binary::binary_numeric_coercion;
1919
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
2020
use arrow::{
2121
compute::can_cast_types,
22-
datatypes::{DataType, Field, TimeUnit},
22+
datatypes::{DataType, TimeUnit},
2323
};
2424
use datafusion_common::types::LogicalType;
25-
use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
25+
use datafusion_common::utils::{
26+
base_type, coerced_fixed_size_list_to_list, ListCoercion,
27+
};
2628
use datafusion_common::{
27-
exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType,
28-
utils::list_ndims, Result,
29+
exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result,
2930
};
3031
use datafusion_expr_common::signature::ArrayFunctionArgument;
32+
use datafusion_expr_common::type_coercion::binary::type_union_resolution;
3133
use datafusion_expr_common::{
3234
signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
3335
type_coercion::binary::comparison_coercion_numeric,
@@ -364,98 +366,73 @@ fn get_valid_types(
364366
return Ok(vec![vec![]]);
365367
}
366368

367-
let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| {
368-
if *arg == ArrayFunctionArgument::Array {
369-
Some(idx)
370-
} else {
371-
None
372-
}
373-
});
374-
let Some(array_idx) = array_idx else {
375-
return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument"));
376-
};
377-
let Some(array_type) = array(&current_types[array_idx]) else {
378-
return Ok(vec![vec![]]);
379-
};
369+
let mut fixed_size = None;
370+
let mut large_list = false;
371+
let mut element_types = Vec::with_capacity(arguments.len());
372+
for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
373+
match argument {
374+
ArrayFunctionArgument::Array => match current_type {
375+
DataType::FixedSizeList(field, size) => {
376+
match array_coercion {
377+
Some(ListCoercion::FixedSizedListToList) => (),
378+
None if fixed_size.is_none() => fixed_size = Some(*size),
379+
None if fixed_size == Some(*size) => (),
380+
None => fixed_size = None,
381+
}
380382

381-
// We need to find the coerced base type, mainly for cases like:
382-
// `array_append(List(null), i64)` -> `List(i64)`
383-
let mut new_base_type = datafusion_common::utils::base_type(&array_type);
384-
for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
385-
match argument_type {
386-
ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => {
387-
new_base_type =
388-
coerce_array_types(function_name, current_type, &new_base_type)?;
383+
element_types.push(field.data_type().clone())
384+
}
385+
DataType::List(field) => {
386+
fixed_size = None;
387+
element_types.push(field.data_type().clone())
388+
}
389+
DataType::LargeList(field) => {
390+
fixed_size = None;
391+
large_list = true;
392+
element_types.push(field.data_type().clone())
393+
}
394+
DataType::Null => {
395+
fixed_size = None;
396+
element_types.push(DataType::Null)
397+
}
398+
arg_type => {
399+
return plan_err!(
400+
"{function_name} does not support an argument of type {arg_type}"
401+
)
402+
}
403+
},
404+
ArrayFunctionArgument::Element => {
405+
element_types.push(current_type.clone())
389406
}
390-
ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {}
407+
ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (),
391408
}
392409
}
393-
let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
394-
&array_type,
395-
&new_base_type,
396-
array_coercion,
397-
);
398410

399-
let new_elem_type = match new_array_type {
400-
DataType::List(ref field)
401-
| DataType::LargeList(ref field)
402-
| DataType::FixedSizeList(ref field, _) => field.data_type(),
403-
_ => return Ok(vec![vec![]]),
411+
let Some(element_type) = type_union_resolution(&element_types) else {
412+
return plan_err!(
413+
"Failed to unify argument types of {function_name}: {current_types:?}."
414+
);
404415
};
405416

406-
let mut valid_types = Vec::with_capacity(arguments.len());
407-
for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
408-
let valid_type = match argument_type {
409-
ArrayFunctionArgument::Element => new_elem_type.clone(),
417+
let array_type = if large_list {
418+
DataType::new_large_list(element_type.clone(), true)
419+
} else if let Some(size) = fixed_size {
420+
DataType::new_fixed_size_list(element_type.clone(), size, true)
421+
} else {
422+
DataType::new_list(element_type.clone(), true)
423+
};
424+
425+
let valid_types = arguments.iter().zip(current_types.iter()).map(
426+
|(argument_type, current_type)| match argument_type {
427+
ArrayFunctionArgument::Array if current_type.is_null() => DataType::Null,
428+
ArrayFunctionArgument::Array => array_type.clone(),
429+
ArrayFunctionArgument::Element => element_type.clone(),
410430
ArrayFunctionArgument::Index => DataType::Int64,
411431
ArrayFunctionArgument::String => DataType::Utf8,
412-
ArrayFunctionArgument::Array => {
413-
let Some(current_type) = array(current_type) else {
414-
return Ok(vec![vec![]]);
415-
};
416-
let new_type =
417-
datafusion_common::utils::coerced_type_with_base_type_only(
418-
&current_type,
419-
&new_base_type,
420-
array_coercion,
421-
);
422-
// All array arguments must be coercible to the same type
423-
if new_type != new_array_type {
424-
return Ok(vec![vec![]]);
425-
}
426-
new_type
427-
}
428-
};
429-
valid_types.push(valid_type);
430-
}
431-
432-
Ok(vec![valid_types])
433-
}
434-
435-
fn array(array_type: &DataType) -> Option<DataType> {
436-
match array_type {
437-
DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
438-
DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
439-
DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field(
440-
DataType::Int64,
441-
true,
442-
)))),
443-
_ => None,
444-
}
445-
}
432+
},
433+
);
446434

447-
fn coerce_array_types(
448-
function_name: &str,
449-
current_type: &DataType,
450-
base_type: &DataType,
451-
) -> Result<DataType> {
452-
let current_base_type = datafusion_common::utils::base_type(current_type);
453-
let new_base_type = comparison_coercion(base_type, &current_base_type);
454-
new_base_type.ok_or_else(|| {
455-
internal_datafusion_err!(
456-
"Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}"
457-
)
458-
})
435+
Ok(vec![valid_types.collect()])
459436
}
460437

461438
fn recursive_array(array_type: &DataType) -> Option<DataType> {
@@ -800,7 +777,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
800777
///
801778
/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
802779
///
803-
/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion.
780+
/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion.
804781
fn coerced_from<'a>(
805782
type_into: &'a DataType,
806783
type_from: &'a DataType,
@@ -867,7 +844,7 @@ fn coerced_from<'a>(
867844
// Only accept list and largelist with the same number of dimensions unless the type is Null.
868845
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this
869846
(List(_) | LargeList(_), _)
870-
if datafusion_common::utils::base_type(type_from).eq(&Null)
847+
if base_type(type_from).is_null()
871848
|| list_ndims(type_from) == list_ndims(type_into) =>
872849
{
873850
Some(type_into.clone())
@@ -906,7 +883,6 @@ fn coerced_from<'a>(
906883

907884
#[cfg(test)]
908885
mod tests {
909-
910886
use crate::Volatility;
911887

912888
use super::*;

datafusion/functions-nested/src/cardinality.rs

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ use arrow::array::{
2323
};
2424
use arrow::datatypes::{
2525
DataType,
26-
DataType::{FixedSizeList, LargeList, List, Map, UInt64},
26+
DataType::{LargeList, List, Map, Null, UInt64},
2727
};
2828
use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array};
29-
use datafusion_common::utils::take_function_args;
29+
use datafusion_common::exec_err;
30+
use datafusion_common::utils::{take_function_args, ListCoercion};
3031
use datafusion_common::Result;
31-
use datafusion_common::{exec_err, plan_err};
3232
use datafusion_expr::{
3333
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
3434
ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -52,7 +52,7 @@ impl Cardinality {
5252
vec![
5353
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
5454
arguments: vec![ArrayFunctionArgument::Array],
55-
array_coercion: None,
55+
array_coercion: Some(ListCoercion::FixedSizedListToList),
5656
}),
5757
TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
5858
],
@@ -103,13 +103,8 @@ impl ScalarUDFImpl for Cardinality {
103103
&self.signature
104104
}
105105

106-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
107-
Ok(match arg_types[0] {
108-
List(_) | LargeList(_) | FixedSizeList(_, _) | Map(_, _) => UInt64,
109-
_ => {
110-
return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList/Map.");
111-
}
112-
})
106+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
107+
Ok(UInt64)
113108
}
114109

115110
fn invoke_with_args(
@@ -131,21 +126,22 @@ impl ScalarUDFImpl for Cardinality {
131126
/// Cardinality SQL function
132127
pub fn cardinality_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
133128
let [array] = take_function_args("cardinality", args)?;
134-
match &array.data_type() {
129+
match array.data_type() {
130+
Null => Ok(Arc::new(UInt64Array::from_value(0, array.len()))),
135131
List(_) => {
136-
let list_array = as_list_array(&array)?;
132+
let list_array = as_list_array(array)?;
137133
generic_list_cardinality::<i32>(list_array)
138134
}
139135
LargeList(_) => {
140-
let list_array = as_large_list_array(&array)?;
136+
let list_array = as_large_list_array(array)?;
141137
generic_list_cardinality::<i64>(list_array)
142138
}
143139
Map(_, _) => {
144-
let map_array = as_map_array(&array)?;
140+
let map_array = as_map_array(array)?;
145141
generic_map_cardinality(map_array)
146142
}
147-
other => {
148-
exec_err!("cardinality does not support type '{:?}'", other)
143+
arg_type => {
144+
exec_err!("cardinality does not support an argument of type '{arg_type}'")
149145
}
150146
}
151147
}

0 commit comments

Comments
 (0)