|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use super::binary::{binary_numeric_coercion, comparison_coercion}; |
| 18 | +use super::binary::binary_numeric_coercion; |
19 | 19 | use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; |
20 | 20 | use arrow::{ |
21 | 21 | compute::can_cast_types, |
22 | | - datatypes::{DataType, Field, TimeUnit}, |
| 22 | + datatypes::{DataType, TimeUnit}, |
23 | 23 | }; |
24 | 24 | use datafusion_common::types::LogicalType; |
25 | | -use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; |
| 25 | +use datafusion_common::utils::{ |
| 26 | + base_type, coerced_fixed_size_list_to_list, ListCoercion, |
| 27 | +}; |
26 | 28 | use datafusion_common::{ |
27 | | - exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, |
28 | | - utils::list_ndims, Result, |
| 29 | + exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result, |
29 | 30 | }; |
30 | 31 | use datafusion_expr_common::signature::ArrayFunctionArgument; |
| 32 | +use datafusion_expr_common::type_coercion::binary::type_union_resolution; |
31 | 33 | use datafusion_expr_common::{ |
32 | 34 | signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, |
33 | 35 | type_coercion::binary::comparison_coercion_numeric, |
@@ -364,98 +366,73 @@ fn get_valid_types( |
364 | 366 | return Ok(vec![vec![]]); |
365 | 367 | } |
366 | 368 |
|
367 | | - let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| { |
368 | | - if *arg == ArrayFunctionArgument::Array { |
369 | | - Some(idx) |
370 | | - } else { |
371 | | - None |
372 | | - } |
373 | | - }); |
374 | | - let Some(array_idx) = array_idx else { |
375 | | - return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument")); |
376 | | - }; |
377 | | - let Some(array_type) = array(¤t_types[array_idx]) else { |
378 | | - return Ok(vec![vec![]]); |
379 | | - }; |
| 369 | + let mut fixed_size = None; |
| 370 | + let mut large_list = false; |
| 371 | + let mut element_types = Vec::with_capacity(arguments.len()); |
| 372 | + for (argument, current_type) in arguments.iter().zip(current_types.iter()) { |
| 373 | + match argument { |
| 374 | + ArrayFunctionArgument::Array => match current_type { |
| 375 | + DataType::FixedSizeList(field, size) => { |
| 376 | + match array_coercion { |
| 377 | + Some(ListCoercion::FixedSizedListToList) => (), |
| 378 | + None if fixed_size.is_none() => fixed_size = Some(*size), |
| 379 | + None if fixed_size == Some(*size) => (), |
| 380 | + None => fixed_size = None, |
| 381 | + } |
380 | 382 |
|
381 | | - // We need to find the coerced base type, mainly for cases like: |
382 | | - // `array_append(List(null), i64)` -> `List(i64)` |
383 | | - let mut new_base_type = datafusion_common::utils::base_type(&array_type); |
384 | | - for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { |
385 | | - match argument_type { |
386 | | - ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => { |
387 | | - new_base_type = |
388 | | - coerce_array_types(function_name, current_type, &new_base_type)?; |
| 383 | + element_types.push(field.data_type().clone()) |
| 384 | + } |
| 385 | + DataType::List(field) => { |
| 386 | + fixed_size = None; |
| 387 | + element_types.push(field.data_type().clone()) |
| 388 | + } |
| 389 | + DataType::LargeList(field) => { |
| 390 | + fixed_size = None; |
| 391 | + large_list = true; |
| 392 | + element_types.push(field.data_type().clone()) |
| 393 | + } |
| 394 | + DataType::Null => { |
| 395 | + fixed_size = None; |
| 396 | + element_types.push(DataType::Null) |
| 397 | + } |
| 398 | + arg_type => { |
| 399 | + return plan_err!( |
| 400 | + "{function_name} does not support an argument of type {arg_type}" |
| 401 | + ) |
| 402 | + } |
| 403 | + }, |
| 404 | + ArrayFunctionArgument::Element => { |
| 405 | + element_types.push(current_type.clone()) |
389 | 406 | } |
390 | | - ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {} |
| 407 | + ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (), |
391 | 408 | } |
392 | 409 | } |
393 | | - let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( |
394 | | - &array_type, |
395 | | - &new_base_type, |
396 | | - array_coercion, |
397 | | - ); |
398 | 410 |
|
399 | | - let new_elem_type = match new_array_type { |
400 | | - DataType::List(ref field) |
401 | | - | DataType::LargeList(ref field) |
402 | | - | DataType::FixedSizeList(ref field, _) => field.data_type(), |
403 | | - _ => return Ok(vec![vec![]]), |
| 411 | + let Some(element_type) = type_union_resolution(&element_types) else { |
| 412 | + return plan_err!( |
| 413 | + "Failed to unify argument types of {function_name}: {current_types:?}." |
| 414 | + ); |
404 | 415 | }; |
405 | 416 |
|
406 | | - let mut valid_types = Vec::with_capacity(arguments.len()); |
407 | | - for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { |
408 | | - let valid_type = match argument_type { |
409 | | - ArrayFunctionArgument::Element => new_elem_type.clone(), |
| 417 | + let array_type = if large_list { |
| 418 | + DataType::new_large_list(element_type.clone(), true) |
| 419 | + } else if let Some(size) = fixed_size { |
| 420 | + DataType::new_fixed_size_list(element_type.clone(), size, true) |
| 421 | + } else { |
| 422 | + DataType::new_list(element_type.clone(), true) |
| 423 | + }; |
| 424 | + |
| 425 | + let valid_types = arguments.iter().zip(current_types.iter()).map( |
| 426 | + |(argument_type, current_type)| match argument_type { |
| 427 | + ArrayFunctionArgument::Array if current_type.is_null() => DataType::Null, |
| 428 | + ArrayFunctionArgument::Array => array_type.clone(), |
| 429 | + ArrayFunctionArgument::Element => element_type.clone(), |
410 | 430 | ArrayFunctionArgument::Index => DataType::Int64, |
411 | 431 | ArrayFunctionArgument::String => DataType::Utf8, |
412 | | - ArrayFunctionArgument::Array => { |
413 | | - let Some(current_type) = array(current_type) else { |
414 | | - return Ok(vec![vec![]]); |
415 | | - }; |
416 | | - let new_type = |
417 | | - datafusion_common::utils::coerced_type_with_base_type_only( |
418 | | - ¤t_type, |
419 | | - &new_base_type, |
420 | | - array_coercion, |
421 | | - ); |
422 | | - // All array arguments must be coercible to the same type |
423 | | - if new_type != new_array_type { |
424 | | - return Ok(vec![vec![]]); |
425 | | - } |
426 | | - new_type |
427 | | - } |
428 | | - }; |
429 | | - valid_types.push(valid_type); |
430 | | - } |
431 | | - |
432 | | - Ok(vec![valid_types]) |
433 | | - } |
434 | | - |
435 | | - fn array(array_type: &DataType) -> Option<DataType> { |
436 | | - match array_type { |
437 | | - DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()), |
438 | | - DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))), |
439 | | - DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field( |
440 | | - DataType::Int64, |
441 | | - true, |
442 | | - )))), |
443 | | - _ => None, |
444 | | - } |
445 | | - } |
| 432 | + }, |
| 433 | + ); |
446 | 434 |
|
447 | | - fn coerce_array_types( |
448 | | - function_name: &str, |
449 | | - current_type: &DataType, |
450 | | - base_type: &DataType, |
451 | | - ) -> Result<DataType> { |
452 | | - let current_base_type = datafusion_common::utils::base_type(current_type); |
453 | | - let new_base_type = comparison_coercion(base_type, ¤t_base_type); |
454 | | - new_base_type.ok_or_else(|| { |
455 | | - internal_datafusion_err!( |
456 | | - "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}" |
457 | | - ) |
458 | | - }) |
| 435 | + Ok(vec![valid_types.collect()]) |
459 | 436 | } |
460 | 437 |
|
461 | 438 | fn recursive_array(array_type: &DataType) -> Option<DataType> { |
@@ -800,7 +777,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { |
800 | 777 | /// |
801 | 778 | /// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32. |
802 | 779 | /// |
803 | | -/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion. |
| 780 | +/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion. |
804 | 781 | fn coerced_from<'a>( |
805 | 782 | type_into: &'a DataType, |
806 | 783 | type_from: &'a DataType, |
@@ -867,7 +844,7 @@ fn coerced_from<'a>( |
867 | 844 | // Only accept list and largelist with the same number of dimensions unless the type is Null. |
868 | 845 | // List or LargeList with different dimensions should be handled in TypeSignature or other places before this |
869 | 846 | (List(_) | LargeList(_), _) |
870 | | - if datafusion_common::utils::base_type(type_from).eq(&Null) |
| 847 | + if base_type(type_from).is_null() |
871 | 848 | || list_ndims(type_from) == list_ndims(type_into) => |
872 | 849 | { |
873 | 850 | Some(type_into.clone()) |
@@ -906,7 +883,6 @@ fn coerced_from<'a>( |
906 | 883 |
|
907 | 884 | #[cfg(test)] |
908 | 885 | mod tests { |
909 | | - |
910 | 886 | use crate::Volatility; |
911 | 887 |
|
912 | 888 | use super::*; |
|
0 commit comments