Skip to content

Commit 448dff5

Browse files
authored
Fix ScalarValue handling of NULL values for ListArray (#7969)
* Fix try_from_array data type for NULL value in ListArray * Fix * Explicitly assert the datatype * For review
1 parent bb1d7f9 commit 448dff5

File tree

2 files changed

+111
-25
lines changed

2 files changed

+111
-25
lines changed

datafusion/common/src/scalar.rs

Lines changed: 100 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,10 +1312,11 @@ impl ScalarValue {
13121312
Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>(
13131313
scalars.into_iter().map(|x| match x {
13141314
ScalarValue::List(arr) => {
1315-
if arr.as_any().downcast_ref::<NullArray>().is_some() {
1315+
// `ScalarValue::List` contains a single element `ListArray`.
1316+
let list_arr = as_list_array(&arr);
1317+
if list_arr.is_null(0) {
13161318
None
13171319
} else {
1318-
let list_arr = as_list_array(&arr);
13191320
let primitive_arr =
13201321
list_arr.values().as_primitive::<$ARRAY_TY>();
13211322
Some(
@@ -1339,12 +1340,14 @@ impl ScalarValue {
13391340
for scalar in scalars.into_iter() {
13401341
match scalar {
13411342
ScalarValue::List(arr) => {
1342-
if arr.as_any().downcast_ref::<NullArray>().is_some() {
1343+
// `ScalarValue::List` contains a single element `ListArray`.
1344+
let list_arr = as_list_array(&arr);
1345+
1346+
if list_arr.is_null(0) {
13431347
builder.append(false);
13441348
continue;
13451349
}
13461350

1347-
let list_arr = as_list_array(&arr);
13481351
let string_arr = $STRING_ARRAY(list_arr.values());
13491352

13501353
for v in string_arr.iter() {
@@ -1699,15 +1702,16 @@ impl ScalarValue {
16991702

17001703
for scalar in scalars {
17011704
if let ScalarValue::List(arr) = scalar {
1702-
// i.e. NullArray(1)
1703-
if arr.as_any().downcast_ref::<NullArray>().is_some() {
1705+
// `ScalarValue::List` contains a single element `ListArray`.
1706+
let list_arr = as_list_array(&arr);
1707+
1708+
if list_arr.is_null(0) {
17041709
// Repeat previous offset index
17051710
offsets.push(0);
17061711

17071712
// Element is null
17081713
valid.append(false);
17091714
} else {
1710-
let list_arr = as_list_array(&arr);
17111715
let arr = list_arr.values().to_owned();
17121716
offsets.push(arr.len());
17131717
elements.push(arr);
@@ -2234,28 +2238,20 @@ impl ScalarValue {
22342238
}
22352239
DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
22362240
DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
2237-
DataType::List(nested_type) => {
2241+
DataType::List(_) => {
22382242
let list_array = as_list_array(array);
2239-
let arr = match list_array.is_null(index) {
2240-
true => new_null_array(nested_type.data_type(), 0),
2241-
false => {
2242-
let nested_array = list_array.value(index);
2243-
Arc::new(wrap_into_list_array(nested_array))
2244-
}
2245-
};
2243+
let nested_array = list_array.value(index);
2244+
// Produces a single element `ListArray` with the value at `index`.
2245+
let arr = Arc::new(wrap_into_list_array(nested_array));
22462246

22472247
ScalarValue::List(arr)
22482248
}
22492249
// TODO: There is no test for FixedSizeList now, add it later
2250-
DataType::FixedSizeList(nested_type, _len) => {
2250+
DataType::FixedSizeList(_, _) => {
22512251
let list_array = as_fixed_size_list_array(array)?;
2252-
let arr = match list_array.is_null(index) {
2253-
true => new_null_array(nested_type.data_type(), 0),
2254-
false => {
2255-
let nested_array = list_array.value(index);
2256-
Arc::new(wrap_into_list_array(nested_array))
2257-
}
2258-
};
2252+
let nested_array = list_array.value(index);
2253+
// Produces a single element `ListArray` with the value at `index`.
2254+
let arr = Arc::new(wrap_into_list_array(nested_array));
22592255

22602256
ScalarValue::List(arr)
22612257
}
@@ -2944,8 +2940,15 @@ impl TryFrom<&DataType> for ScalarValue {
29442940
index_type.clone(),
29452941
Box::new(value_type.as_ref().try_into()?),
29462942
),
2947-
DataType::List(_) => ScalarValue::List(new_null_array(&DataType::Null, 0)),
2948-
2943+
// `ScalaValue::List` contains single element `ListArray`.
2944+
DataType::List(field) => ScalarValue::List(new_null_array(
2945+
&DataType::List(Arc::new(Field::new(
2946+
"item",
2947+
field.data_type().clone(),
2948+
true,
2949+
))),
2950+
1,
2951+
)),
29492952
DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()),
29502953
DataType::Null => ScalarValue::Null,
29512954
_ => {
@@ -3885,6 +3888,78 @@ mod tests {
38853888
);
38863889
}
38873890

3891+
#[test]
3892+
fn scalar_try_from_array_list_array_null() {
3893+
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
3894+
Some(vec![Some(1), Some(2)]),
3895+
None,
3896+
]);
3897+
3898+
let non_null_list_scalar = ScalarValue::try_from_array(&list, 0).unwrap();
3899+
let null_list_scalar = ScalarValue::try_from_array(&list, 1).unwrap();
3900+
3901+
let data_type =
3902+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
3903+
3904+
assert_eq!(non_null_list_scalar.data_type(), data_type.clone());
3905+
assert_eq!(null_list_scalar.data_type(), data_type);
3906+
}
3907+
3908+
#[test]
3909+
fn scalar_try_from_list() {
3910+
let data_type =
3911+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
3912+
let data_type = &data_type;
3913+
let scalar: ScalarValue = data_type.try_into().unwrap();
3914+
3915+
let expected = ScalarValue::List(new_null_array(
3916+
&DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3917+
1,
3918+
));
3919+
3920+
assert_eq!(expected, scalar)
3921+
}
3922+
3923+
#[test]
3924+
fn scalar_try_from_list_of_list() {
3925+
let data_type = DataType::List(Arc::new(Field::new(
3926+
"item",
3927+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3928+
true,
3929+
)));
3930+
let data_type = &data_type;
3931+
let scalar: ScalarValue = data_type.try_into().unwrap();
3932+
3933+
let expected = ScalarValue::List(new_null_array(
3934+
&DataType::List(Arc::new(Field::new(
3935+
"item",
3936+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3937+
true,
3938+
))),
3939+
1,
3940+
));
3941+
3942+
assert_eq!(expected, scalar)
3943+
}
3944+
3945+
#[test]
3946+
fn scalar_try_from_not_equal_list_nested_list() {
3947+
let list_data_type =
3948+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
3949+
let data_type = &list_data_type;
3950+
let list_scalar: ScalarValue = data_type.try_into().unwrap();
3951+
3952+
let nested_list_data_type = DataType::List(Arc::new(Field::new(
3953+
"item",
3954+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3955+
true,
3956+
)));
3957+
let data_type = &nested_list_data_type;
3958+
let nested_list_scalar: ScalarValue = data_type.try_into().unwrap();
3959+
3960+
assert_ne!(list_scalar, nested_list_scalar);
3961+
}
3962+
38883963
#[test]
38893964
fn scalar_try_from_dict_datatype() {
38903965
let data_type =

datafusion/sqllogictest/test_files/array.slt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,17 @@ AS VALUES
209209
(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10)
210210
;
211211

212+
query TTT
213+
select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays;
214+
----
215+
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
216+
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
217+
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
218+
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
219+
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
220+
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
221+
List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
222+
212223
# arrays table
213224
query ???
214225
select column1, column2, column3 from arrays;

0 commit comments

Comments
 (0)