Skip to content

Commit be32301

Browse files
committed
fix sql_array_literal
Signed-off-by: veeupup <code@tanweime.com>
1 parent c08d6cb commit be32301

File tree

3 files changed

+51
-91
lines changed

3 files changed

+51
-91
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 6 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -599,19 +599,10 @@ impl BuiltinScalarFunction {
599599
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
600600
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
601601
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
602-
BuiltinScalarFunction::ArrayIntersect => {
602+
BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => {
603603
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
604-
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
605-
Field::new("item", DataType::Null, true),
606-
))),
607-
(dt, _) => Ok(dt),
608-
}
609-
}
610-
BuiltinScalarFunction::ArrayUnion => {
611-
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
612-
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
613-
Field::new("item", DataType::Null, true),
614-
))),
604+
(DataType::Null, dt) => Ok(dt),
605+
(dt, DataType::Null) => Ok(dt),
615606
(dt, _) => Ok(dt),
616607
}
617608
}
@@ -620,9 +611,9 @@ impl BuiltinScalarFunction {
620611
}
621612
BuiltinScalarFunction::ArrayExcept => {
622613
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
623-
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
624-
Field::new("item", DataType::Null, true),
625-
))),
614+
(DataType::Null, _) | (_, DataType::Null) => {
615+
Ok(input_expr_types[0].clone())
616+
}
626617
(dt, _) => Ok(dt),
627618
}
628619
}

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 28 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,10 @@ fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>>
228228

229229
fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
230230
let data_type = args[0].data_type();
231-
if !args
232-
.iter()
233-
.all(|arg| arg.data_type().equals_datatype(data_type))
234-
{
231+
if !args.iter().all(|arg| {
232+
arg.data_type().equals_datatype(data_type)
233+
|| arg.data_type().equals_datatype(&DataType::Null)
234+
}) {
235235
let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
236236
return plan_err!("{name} received incompatible types: '{types:?}'.");
237237
}
@@ -580,21 +580,6 @@ pub fn array_except(args: &[ArrayRef]) -> Result<ArrayRef> {
580580
let array2 = &args[1];
581581

582582
match (array1.data_type(), array2.data_type()) {
583-
(DataType::Null, DataType::Null) => {
584-
// NullArray(1): means null, NullArray(0): means []
585-
// except([], []) = [], except([], null) = [], except(null, []) = null, except(null, null) = null
586-
let nulls = match (array1.len(), array2.len()) {
587-
(1, _) => Some(NullBuffer::new_null(1)),
588-
_ => None,
589-
};
590-
let arr = Arc::new(ListArray::try_new(
591-
Arc::new(Field::new("item", DataType::Null, true)),
592-
OffsetBuffer::new(vec![0; 2].into()),
593-
Arc::new(NullArray::new(0)),
594-
nulls,
595-
)?) as ArrayRef;
596-
Ok(arr)
597-
}
598583
(DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()),
599584
(DataType::List(field), DataType::List(_)) => {
600585
check_datatypes("array_except", &[array1, array2])?;
@@ -1525,36 +1510,31 @@ pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
15251510
let array1 = &args[0];
15261511
let array2 = &args[1];
15271512
match (array1.data_type(), array2.data_type()) {
1528-
(DataType::Null, DataType::Null) => {
1529-
// NullArray(1): means null, NullArray(0): means []
1530-
// union([], []) = [], union([], null) = [], union(null, []) = [], union(null, null) = null
1531-
let nulls = match (array1.len(), array2.len()) {
1532-
(1, 1) => Some(NullBuffer::new_null(1)),
1533-
_ => None,
1534-
};
1535-
let arr = Arc::new(ListArray::try_new(
1536-
Arc::new(Field::new("item", DataType::Null, true)),
1537-
OffsetBuffer::new(vec![0; 2].into()),
1538-
Arc::new(NullArray::new(0)),
1539-
nulls,
1540-
)?) as ArrayRef;
1541-
Ok(arr)
1542-
}
15431513
(DataType::Null, _) => Ok(array2.clone()),
15441514
(_, DataType::Null) => Ok(array1.clone()),
1545-
(DataType::List(field_ref), DataType::List(_)) => {
1546-
check_datatypes("array_union", &[array1, array2])?;
1547-
let list1 = array1.as_list::<i32>();
1548-
let list2 = array2.as_list::<i32>();
1549-
let result = union_generic_lists::<i32>(list1, list2, field_ref)?;
1550-
Ok(Arc::new(result))
1515+
(DataType::List(l_field_ref), DataType::List(r_field_ref)) => {
1516+
match (l_field_ref.data_type(), r_field_ref.data_type()) {
1517+
(DataType::Null, _) => Ok(array2.clone()),
1518+
(_, DataType::Null) => Ok(array1.clone()),
1519+
(_, _) => {
1520+
let list1 = array1.as_list::<i32>();
1521+
let list2 = array2.as_list::<i32>();
1522+
let result = union_generic_lists::<i32>(list1, list2, &l_field_ref)?;
1523+
Ok(Arc::new(result))
1524+
}
1525+
}
15511526
}
1552-
(DataType::LargeList(field_ref), DataType::LargeList(_)) => {
1553-
check_datatypes("array_union", &[array1, array2])?;
1554-
let list1 = array1.as_list::<i64>();
1555-
let list2 = array2.as_list::<i64>();
1556-
let result = union_generic_lists::<i64>(list1, list2, field_ref)?;
1557-
Ok(Arc::new(result))
1527+
(DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => {
1528+
match (l_field_ref.data_type(), r_field_ref.data_type()) {
1529+
(DataType::Null, _) => Ok(array2.clone()),
1530+
(_, DataType::Null) => Ok(array1.clone()),
1531+
(_, _) => {
1532+
let list1 = array1.as_list::<i64>();
1533+
let list2 = array2.as_list::<i64>();
1534+
let result = union_generic_lists::<i64>(list1, list2, &l_field_ref)?;
1535+
Ok(Arc::new(result))
1536+
}
1537+
}
15581538
}
15591539
_ => {
15601540
internal_err!(
@@ -2032,21 +2012,8 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
20322012
let second_array = &args[1];
20332013

20342014
match (first_array.data_type(), second_array.data_type()) {
2035-
(DataType::Null, DataType::Null) => {
2036-
// NullArray(1): means null, NullArray(0): means []
2037-
// intersect([], []) = [], intersect([], null) = [], intersect(null, []) = [], intersect(null, null) = null
2038-
let nulls = match (first_array.len(), second_array.len()) {
2039-
(1, 1) => Some(NullBuffer::new_null(1)),
2040-
_ => None,
2041-
};
2042-
let arr = Arc::new(ListArray::try_new(
2043-
Arc::new(Field::new("item", DataType::Null, true)),
2044-
OffsetBuffer::new(vec![0; 2].into()),
2045-
Arc::new(NullArray::new(0)),
2046-
nulls,
2047-
)?) as ArrayRef;
2048-
Ok(arr)
2049-
}
2015+
(DataType::Null, _) => Ok(second_array.clone()),
2016+
(_, DataType::Null) => Ok(first_array.clone()),
20502017
_ => {
20512018
let first_array = as_list_array(&first_array)?;
20522019
let second_array = as_list_array(&second_array)?;

datafusion/sql/src/expr/value.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,20 @@
1616
// under the License.
1717

1818
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
19-
use arrow::array::new_null_array;
2019
use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano;
2120
use arrow::datatypes::DECIMAL128_MAX_PRECISION;
2221
use arrow_schema::DataType;
2322
use datafusion_common::{
2423
not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue,
2524
};
25+
use datafusion_expr::expr::ScalarFunction;
2626
use datafusion_expr::expr::{BinaryExpr, Placeholder};
27+
use datafusion_expr::BuiltinScalarFunction;
2728
use datafusion_expr::{lit, Expr, Operator};
2829
use log::debug;
2930
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
3031
use sqlparser::parser::ParserError::ParserError;
3132
use std::borrow::Cow;
32-
use std::collections::HashSet;
3333

3434
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
3535
pub(crate) fn parse_value(
@@ -138,9 +138,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
138138
schema,
139139
&mut PlannerContext::new(),
140140
)?;
141+
141142
match value {
142143
Expr::Literal(scalar) => {
143-
values.push(scalar);
144+
values.push(Expr::Literal(scalar));
145+
}
146+
Expr::ScalarFunction(ref scalar_function) => {
147+
if scalar_function.fun == BuiltinScalarFunction::MakeArray {
148+
values.push(Expr::ScalarFunction(scalar_function.clone()));
149+
} else {
150+
return not_impl_err!(
151+
"ScalarFunctions without MakeArray are not supported: {value}"
152+
);
153+
}
144154
}
145155
_ => {
146156
return not_impl_err!(
@@ -150,18 +160,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
150160
}
151161
}
152162

153-
let data_types: HashSet<DataType> =
154-
values.iter().map(|e| e.data_type()).collect();
155-
156-
if data_types.is_empty() {
157-
Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0))))
158-
} else if data_types.len() > 1 {
159-
not_impl_err!("Arrays with different types are not supported: {data_types:?}")
160-
} else {
161-
let data_type = values[0].data_type();
162-
let arr = ScalarValue::new_list(&values, &data_type);
163-
Ok(lit(ScalarValue::List(arr)))
164-
}
163+
Ok(Expr::ScalarFunction(ScalarFunction::new(
164+
BuiltinScalarFunction::MakeArray,
165+
values,
166+
)))
165167
}
166168

167169
/// Convert a SQL interval expression to a DataFusion logical plan

0 commit comments

Comments
 (0)