Skip to content

Commit bfce076

Browse files
authored
Refactor downcasting functions with downcastvalue macro and improve error handling of ListArray downcasting (#4313)
* refactor casting with downcastvalue macro and add list array downcasting * fix clippy
1 parent 92325bf commit bfce076

File tree

7 files changed

+41
-126
lines changed

7 files changed

+41
-126
lines changed

datafusion/common/src/cast.rs

Lines changed: 19 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -20,132 +20,71 @@
2020
//! but provide an error message rather than a panic, as the corresponding
2121
//! kernels in arrow-rs such as `as_boolean_array` do.
2222
23-
use crate::DataFusionError;
23+
use crate::{downcast_value, DataFusionError};
2424
use arrow::array::{
2525
Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
26-
Int32Array, Int64Array, StringArray, StructArray, UInt32Array, UInt64Array,
26+
Int32Array, Int64Array, ListArray, StringArray, StructArray, UInt32Array,
27+
UInt64Array,
2728
};
2829

2930
// Downcast ArrayRef to Date32Array
3031
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> {
31-
array.as_any().downcast_ref::<Date32Array>().ok_or_else(|| {
32-
DataFusionError::Internal(format!(
33-
"Expected a Date32Array, got: {}",
34-
array.data_type()
35-
))
36-
})
32+
Ok(downcast_value!(array, Date32Array))
3733
}
3834

3935
// Downcast ArrayRef to StructArray
4036
pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionError> {
41-
array.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
42-
DataFusionError::Internal(format!(
43-
"Expected a StructArray, got: {}",
44-
array.data_type()
45-
))
46-
})
37+
Ok(downcast_value!(array, StructArray))
4738
}
4839

4940
// Downcast ArrayRef to Int32Array
5041
pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array, DataFusionError> {
51-
array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
52-
DataFusionError::Internal(format!(
53-
"Expected a Int32Array, got: {}",
54-
array.data_type()
55-
))
56-
})
42+
Ok(downcast_value!(array, Int32Array))
5743
}
5844

5945
// Downcast ArrayRef to Int64Array
6046
pub fn as_int64_array(array: &dyn Array) -> Result<&Int64Array, DataFusionError> {
61-
array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
62-
DataFusionError::Internal(format!(
63-
"Expected a Int64Array, got: {}",
64-
array.data_type()
65-
))
66-
})
47+
Ok(downcast_value!(array, Int64Array))
6748
}
6849

6950
// Downcast ArrayRef to Decimal128Array
7051
pub fn as_decimal128_array(
7152
array: &dyn Array,
7253
) -> Result<&Decimal128Array, DataFusionError> {
73-
array
74-
.as_any()
75-
.downcast_ref::<Decimal128Array>()
76-
.ok_or_else(|| {
77-
DataFusionError::Internal(format!(
78-
"Expected a Decimal128Array, got: {}",
79-
array.data_type()
80-
))
81-
})
54+
Ok(downcast_value!(array, Decimal128Array))
8255
}
8356

8457
// Downcast ArrayRef to Float32Array
8558
pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array, DataFusionError> {
86-
array
87-
.as_any()
88-
.downcast_ref::<Float32Array>()
89-
.ok_or_else(|| {
90-
DataFusionError::Internal(format!(
91-
"Expected a Float32Array, got: {}",
92-
array.data_type()
93-
))
94-
})
59+
Ok(downcast_value!(array, Float32Array))
9560
}
9661

9762
// Downcast ArrayRef to Float64Array
9863
pub fn as_float64_array(array: &dyn Array) -> Result<&Float64Array, DataFusionError> {
99-
array
100-
.as_any()
101-
.downcast_ref::<Float64Array>()
102-
.ok_or_else(|| {
103-
DataFusionError::Internal(format!(
104-
"Expected a Float64Array, got: {}",
105-
array.data_type()
106-
))
107-
})
64+
Ok(downcast_value!(array, Float64Array))
10865
}
10966

11067
// Downcast ArrayRef to StringArray
11168
pub fn as_string_array(array: &dyn Array) -> Result<&StringArray, DataFusionError> {
112-
array.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
113-
DataFusionError::Internal(format!(
114-
"Expected a StringArray, got: {}",
115-
array.data_type()
116-
))
117-
})
69+
Ok(downcast_value!(array, StringArray))
11870
}
11971

12072
// Downcast ArrayRef to UInt32Array
12173
pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array, DataFusionError> {
122-
array.as_any().downcast_ref::<UInt32Array>().ok_or_else(|| {
123-
DataFusionError::Internal(format!(
124-
"Expected a UInt32Array, got: {}",
125-
array.data_type()
126-
))
127-
})
74+
Ok(downcast_value!(array, UInt32Array))
12875
}
12976

13077
// Downcast ArrayRef to UInt64Array
13178
pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array, DataFusionError> {
132-
array.as_any().downcast_ref::<UInt64Array>().ok_or_else(|| {
133-
DataFusionError::Internal(format!(
134-
"Expected a UInt64Array, got: {}",
135-
array.data_type()
136-
))
137-
})
79+
Ok(downcast_value!(array, UInt64Array))
13880
}
13981

14082
// Downcast ArrayRef to BooleanArray
14183
pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, DataFusionError> {
142-
array
143-
.as_any()
144-
.downcast_ref::<BooleanArray>()
145-
.ok_or_else(|| {
146-
DataFusionError::Internal(format!(
147-
"Expected a BooleanArray, got: {}",
148-
array.data_type()
149-
))
150-
})
84+
Ok(downcast_value!(array, BooleanArray))
85+
}
86+
87+
// Downcast ArrayRef to ListArray
88+
pub fn as_list_array(array: &dyn Array) -> Result<&ListArray, DataFusionError> {
89+
Ok(downcast_value!(array, ListArray))
15190
}

datafusion/common/src/scalar.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use std::ops::{Add, Sub};
2424
use std::str::FromStr;
2525
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
2626

27-
use crate::cast::{as_decimal128_array, as_struct_array};
27+
use crate::cast::{as_decimal128_array, as_list_array, as_struct_array};
2828
use crate::delta::shift_months;
2929
use crate::error::{DataFusionError, Result};
3030
use arrow::{
@@ -2001,12 +2001,7 @@ impl ScalarValue {
20012001
DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
20022002
DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
20032003
DataType::List(nested_type) => {
2004-
let list_array =
2005-
array.as_any().downcast_ref::<ListArray>().ok_or_else(|| {
2006-
DataFusionError::Internal(
2007-
"Failed to downcast ListArray".to_string(),
2008-
)
2009-
})?;
2004+
let list_array = as_list_array(array)?;
20102005
let value = match list_array.is_null(index) {
20112006
true => None,
20122007
false => {
@@ -2940,7 +2935,7 @@ mod tests {
29402935
Box::new(Field::new("item", DataType::UInt64, false)),
29412936
)
29422937
.to_array();
2943-
let list_array = list_array_ref.as_any().downcast_ref::<ListArray>().unwrap();
2938+
let list_array = as_list_array(&list_array_ref).unwrap();
29442939

29452940
assert!(list_array.is_null(0));
29462941
assert_eq!(list_array.len(), 1);
@@ -2959,7 +2954,7 @@ mod tests {
29592954
)
29602955
.to_array();
29612956

2962-
let list_array = list_array_ref.as_any().downcast_ref::<ListArray>().unwrap();
2957+
let list_array = as_list_array(&list_array_ref)?;
29632958
assert_eq!(list_array.len(), 1);
29642959
assert_eq!(list_array.values().len(), 3);
29652960

@@ -3758,7 +3753,7 @@ mod tests {
37583753
let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.get_datatype());
37593754
// iter_to_array for list-of-struct
37603755
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
3761-
let array = array.as_any().downcast_ref::<ListArray>().unwrap();
3756+
let array = as_list_array(&array).unwrap();
37623757

37633758
// Construct expected array with array builders
37643759
let field_a_builder = StringBuilder::with_capacity(4, 1024);
@@ -3922,7 +3917,7 @@ mod tests {
39223917
);
39233918

39243919
let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
3925-
let array = array.as_any().downcast_ref::<ListArray>().unwrap();
3920+
let array = as_list_array(&array).unwrap();
39263921

39273922
// Construct expected array with array builders
39283923
let inner_builder = Int32Array::builder(8);

datafusion/core/src/avro_to_arrow/arrow_array_reader.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -975,9 +975,9 @@ mod test {
975975
use crate::arrow::array::Array;
976976
use crate::arrow::datatypes::{Field, TimeUnit};
977977
use crate::avro_to_arrow::{Reader, ReaderBuilder};
978-
use arrow::array::{ListArray, TimestampMicrosecondArray};
978+
use arrow::array::TimestampMicrosecondArray;
979979
use arrow::datatypes::DataType;
980-
use datafusion_common::cast::{as_int32_array, as_int64_array};
980+
use datafusion_common::cast::{as_int32_array, as_int64_array, as_list_array};
981981
use std::fs::File;
982982

983983
fn build_reader(name: &str, batch_size: usize) -> Reader<File> {
@@ -1034,11 +1034,7 @@ mod test {
10341034
let batch = reader.next().unwrap().unwrap();
10351035
assert_eq!(batch.num_columns(), 2);
10361036
assert_eq!(batch.num_rows(), 3);
1037-
let a_array = batch
1038-
.column(col_id_index)
1039-
.as_any()
1040-
.downcast_ref::<ListArray>()
1041-
.unwrap();
1037+
let a_array = as_list_array(batch.column(col_id_index)).unwrap();
10421038
assert_eq!(
10431039
*a_array.data_type(),
10441040
DataType::List(Box::new(Field::new("bigint", DataType::Int64, true)))

datafusion/core/tests/sql/parquet.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use std::{fs, path::Path};
1919

2020
use ::parquet::arrow::ArrowWriter;
2121
use datafusion::datasource::listing::ListingOptions;
22-
use datafusion_common::cast::as_string_array;
22+
use datafusion_common::cast::{as_list_array, as_string_array};
2323
use tempfile::TempDir;
2424

2525
use super::*;
@@ -235,16 +235,8 @@ async fn parquet_list_columns() {
235235
assert_eq!(2, batch.num_columns());
236236
assert_eq!(schema, batch.schema());
237237

238-
let int_list_array = batch
239-
.column(0)
240-
.as_any()
241-
.downcast_ref::<ListArray>()
242-
.unwrap();
243-
let utf8_list_array = batch
244-
.column(1)
245-
.as_any()
246-
.downcast_ref::<ListArray>()
247-
.unwrap();
238+
let int_list_array = as_list_array(batch.column(0)).unwrap();
239+
let utf8_list_array = as_list_array(batch.column(1)).unwrap();
248240

249241
assert_eq!(
250242
int_list_array

datafusion/physical-expr/src/aggregate/count_distinct.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,11 @@ mod tests {
226226
use crate::aggregate::utils::get_accum_scalar_values;
227227
use arrow::array::{
228228
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
229-
Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array,
230-
UInt8Array,
229+
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
231230
};
232231
use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
233232
use arrow::datatypes::DataType;
233+
use datafusion_common::cast::as_list_array;
234234

235235
macro_rules! state_to_vec {
236236
($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
@@ -380,7 +380,7 @@ mod tests {
380380
let agg = DistinctCount::new(
381381
arrays
382382
.iter()
383-
.map(|a| a.as_any().downcast_ref::<ListArray>().unwrap())
383+
.map(|a| as_list_array(a).unwrap())
384384
.map(|a| a.values().data_type().clone())
385385
.collect::<Vec<_>>(),
386386
vec![],

datafusion/physical-expr/src/expressions/get_indexed_field.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@
1919
2020
use crate::PhysicalExpr;
2121
use arrow::array::Array;
22-
use arrow::array::ListArray;
2322
use arrow::compute::concat;
2423

2524
use crate::physical_expr::down_cast_any_ref;
2625
use arrow::{
2726
datatypes::{DataType, Schema},
2827
record_batch::RecordBatch,
2928
};
30-
use datafusion_common::cast::as_struct_array;
29+
use datafusion_common::cast::{as_list_array, as_struct_array};
3130
use datafusion_common::DataFusionError;
3231
use datafusion_common::Result;
3332
use datafusion_common::ScalarValue;
@@ -91,8 +90,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
9190
Ok(ColumnarValue::Scalar(scalar_null))
9291
}
9392
(DataType::List(_), ScalarValue::Int64(Some(i))) => {
94-
let as_list_array =
95-
array.as_any().downcast_ref::<ListArray>().unwrap();
93+
let as_list_array = as_list_array(&array)?;
9694

9795
if *i < 1 || as_list_array.is_empty() {
9896
let scalar_null: ScalarValue = array.data_type().try_into()?;
@@ -349,10 +347,7 @@ mod tests {
349347
let get_list_expr =
350348
Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key));
351349
let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows());
352-
let result = result
353-
.as_any()
354-
.downcast_ref::<ListArray>()
355-
.unwrap_or_else(|| panic!("failed to downcast to ListArray : {:?}", result));
350+
let result = as_list_array(&result)?;
356351
let expected =
357352
&build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect());
358353
assert_eq!(expected, result);

datafusion/physical-expr/src/functions.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2847,8 +2847,7 @@ mod tests {
28472847
#[test]
28482848
#[cfg(feature = "regex_expressions")]
28492849
fn test_regexp_match() -> Result<()> {
2850-
use arrow::array::ListArray;
2851-
use datafusion_common::cast::as_string_array;
2850+
use datafusion_common::cast::{as_list_array, as_string_array};
28522851
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
28532852
let execution_props = ExecutionProps::new();
28542853

@@ -2873,7 +2872,7 @@ mod tests {
28732872
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
28742873

28752874
// downcast works
2876-
let result = result.as_any().downcast_ref::<ListArray>().unwrap();
2875+
let result = as_list_array(&result)?;
28772876
let first_row = result.value(0);
28782877
let first_row = as_string_array(&first_row)?;
28792878

@@ -2887,8 +2886,7 @@ mod tests {
28872886
#[test]
28882887
#[cfg(feature = "regex_expressions")]
28892888
fn test_regexp_match_all_literals() -> Result<()> {
2890-
use arrow::array::ListArray;
2891-
use datafusion_common::cast::as_string_array;
2889+
use datafusion_common::cast::{as_list_array, as_string_array};
28922890
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
28932891
let execution_props = ExecutionProps::new();
28942892

@@ -2913,7 +2911,7 @@ mod tests {
29132911
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
29142912

29152913
// downcast works
2916-
let result = result.as_any().downcast_ref::<ListArray>().unwrap();
2914+
let result = as_list_array(&result)?;
29172915
let first_row = result.value(0);
29182916
let first_row = as_string_array(&first_row)?;
29192917

0 commit comments

Comments
 (0)