Skip to content

Commit 712b9fd

Browse files
authored
improve error messages while downcasting uint and boolean array (#4261)
1 parent 880e6fc commit 712b9fd

File tree

19 files changed

+128
-130
lines changed

19 files changed

+128
-130
lines changed

datafusion/common/src/cast.rs

Lines changed: 35 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,8 @@
2222
2323
use crate::DataFusionError;
2424
use arrow::array::{
25-
Array, Date32Array, Decimal128Array, Float32Array, Float64Array, Int32Array,
26-
Int64Array, StringArray, StructArray,
25+
Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array,
26+
Int32Array, Int64Array, StringArray, StructArray, UInt32Array, UInt64Array,
2727
};
2828

2929
// Downcast ArrayRef to Date32Array
@@ -116,3 +116,36 @@ pub fn as_string_array(array: &dyn Array) -> Result<&StringArray, DataFusionErro
116116
))
117117
})
118118
}
119+
120+
// Downcast ArrayRef to UInt32Array
121+
pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array, DataFusionError> {
122+
array.as_any().downcast_ref::<UInt32Array>().ok_or_else(|| {
123+
DataFusionError::Internal(format!(
124+
"Expected a UInt32Array, got: {}",
125+
array.data_type()
126+
))
127+
})
128+
}
129+
130+
// Downcast ArrayRef to UInt64Array
131+
pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array, DataFusionError> {
132+
array.as_any().downcast_ref::<UInt64Array>().ok_or_else(|| {
133+
DataFusionError::Internal(format!(
134+
"Expected a UInt64Array, got: {}",
135+
array.data_type()
136+
))
137+
})
138+
}
139+
140+
// Downcast ArrayRef to BooleanArray
141+
pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, DataFusionError> {
142+
array
143+
.as_any()
144+
.downcast_ref::<BooleanArray>()
145+
.ok_or_else(|| {
146+
DataFusionError::Internal(format!(
147+
"Expected a BooleanArray, got: {}",
148+
array.data_type()
149+
))
150+
})
151+
}

datafusion/common/src/scalar.rs

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2657,7 +2657,7 @@ mod tests {
26572657
use arrow::compute::kernels;
26582658
use arrow::datatypes::ArrowPrimitiveType;
26592659

2660-
use crate::cast::as_string_array;
2660+
use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
26612661
use crate::from_slice::FromSlice;
26622662

26632663
use super::*;
@@ -2792,35 +2792,37 @@ mod tests {
27922792
}
27932793

27942794
#[test]
2795-
fn scalar_value_to_array_u64() {
2795+
fn scalar_value_to_array_u64() -> Result<()> {
27962796
let value = ScalarValue::UInt64(Some(13u64));
27972797
let array = value.to_array();
2798-
let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
2798+
let array = as_uint64_array(&array)?;
27992799
assert_eq!(array.len(), 1);
28002800
assert!(!array.is_null(0));
28012801
assert_eq!(array.value(0), 13);
28022802

28032803
let value = ScalarValue::UInt64(None);
28042804
let array = value.to_array();
2805-
let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
2805+
let array = as_uint64_array(&array)?;
28062806
assert_eq!(array.len(), 1);
28072807
assert!(array.is_null(0));
2808+
Ok(())
28082809
}
28092810

28102811
#[test]
2811-
fn scalar_value_to_array_u32() {
2812+
fn scalar_value_to_array_u32() -> Result<()> {
28122813
let value = ScalarValue::UInt32(Some(13u32));
28132814
let array = value.to_array();
2814-
let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
2815+
let array = as_uint32_array(&array)?;
28152816
assert_eq!(array.len(), 1);
28162817
assert!(!array.is_null(0));
28172818
assert_eq!(array.value(0), 13);
28182819

28192820
let value = ScalarValue::UInt32(None);
28202821
let array = value.to_array();
2821-
let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
2822+
let array = as_uint32_array(&array)?;
28222823
assert_eq!(array.len(), 1);
28232824
assert!(array.is_null(0));
2825+
Ok(())
28242826
}
28252827

28262828
#[test]
@@ -2838,7 +2840,7 @@ mod tests {
28382840
}
28392841

28402842
#[test]
2841-
fn scalar_list_to_array() {
2843+
fn scalar_list_to_array() -> Result<()> {
28422844
let list_array_ref = ScalarValue::List(
28432845
Some(vec![
28442846
ScalarValue::UInt64(Some(100)),
@@ -2854,14 +2856,12 @@ mod tests {
28542856
assert_eq!(list_array.values().len(), 3);
28552857

28562858
let prim_array_ref = list_array.value(0);
2857-
let prim_array = prim_array_ref
2858-
.as_any()
2859-
.downcast_ref::<UInt64Array>()
2860-
.unwrap();
2859+
let prim_array = as_uint64_array(&prim_array_ref)?;
28612860
assert_eq!(prim_array.len(), 3);
28622861
assert_eq!(prim_array.value(0), 100);
28632862
assert!(prim_array.is_null(1));
28642863
assert_eq!(prim_array.value(2), 101);
2864+
Ok(())
28652865
}
28662866

28672867
/// Creates array directly and via ScalarValue and ensures they are the same

datafusion/core/src/datasource/file_format/avro.rs

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -92,8 +92,10 @@ mod tests {
9292
use crate::datasource::file_format::test_util::scan_format;
9393
use crate::physical_plan::collect;
9494
use crate::prelude::{SessionConfig, SessionContext};
95-
use arrow::array::{BinaryArray, BooleanArray, TimestampMicrosecondArray};
96-
use datafusion_common::cast::{as_float32_array, as_float64_array, as_int32_array};
95+
use arrow::array::{BinaryArray, TimestampMicrosecondArray};
96+
use datafusion_common::cast::{
97+
as_boolean_array, as_float32_array, as_float64_array, as_int32_array,
98+
};
9799
use futures::StreamExt;
98100

99101
#[tokio::test]
@@ -197,11 +199,7 @@ mod tests {
197199
assert_eq!(1, batches[0].num_columns());
198200
assert_eq!(8, batches[0].num_rows());
199201

200-
let array = batches[0]
201-
.column(0)
202-
.as_any()
203-
.downcast_ref::<BooleanArray>()
204-
.unwrap();
202+
let array = as_boolean_array(batches[0].column(0))?;
205203
let mut values: Vec<bool> = vec![];
206204
for i in 0..batches[0].num_rows() {
207205
values.push(array.value(i));

datafusion/core/src/datasource/file_format/parquet.rs

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -586,12 +586,14 @@ mod tests {
586586
use crate::physical_plan::metrics::MetricValue;
587587
use crate::prelude::{SessionConfig, SessionContext};
588588
use arrow::array::{
589-
Array, ArrayRef, BinaryArray, BooleanArray, StringArray, TimestampNanosecondArray,
589+
Array, ArrayRef, BinaryArray, StringArray, TimestampNanosecondArray,
590590
};
591591
use arrow::record_batch::RecordBatch;
592592
use async_trait::async_trait;
593593
use bytes::Bytes;
594-
use datafusion_common::cast::{as_float32_array, as_float64_array, as_int32_array};
594+
use datafusion_common::cast::{
595+
as_boolean_array, as_float32_array, as_float64_array, as_int32_array,
596+
};
595597
use datafusion_common::ScalarValue;
596598
use futures::stream::BoxStream;
597599
use futures::StreamExt;
@@ -945,11 +947,7 @@ mod tests {
945947
assert_eq!(1, batches[0].num_columns());
946948
assert_eq!(8, batches[0].num_rows());
947949

948-
let array = batches[0]
949-
.column(0)
950-
.as_any()
951-
.downcast_ref::<BooleanArray>()
952-
.unwrap();
950+
let array = as_boolean_array(batches[0].column(0))?;
953951
let mut values: Vec<bool> = vec![];
954952
for i in 0..batches[0].num_rows() {
955953
values.push(array.value(i));

datafusion/core/src/datasource/listing/helpers.rs

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ use std::sync::Arc;
2222
use arrow::{
2323
array::{
2424
Array, ArrayBuilder, ArrayRef, Date64Array, Date64Builder, StringBuilder,
25-
UInt64Array, UInt64Builder,
25+
UInt64Builder,
2626
},
2727
datatypes::{DataType, Field, Schema},
2828
record_batch::RecordBatch,
@@ -38,7 +38,10 @@ use crate::{
3838

3939
use super::PartitionedFile;
4040
use crate::datasource::listing::ListingTableUrl;
41-
use datafusion_common::{cast::as_string_array, Column, DataFusionError};
41+
use datafusion_common::{
42+
cast::{as_string_array, as_uint64_array},
43+
Column, DataFusionError,
44+
};
4245
use datafusion_expr::{
4346
expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion},
4447
Expr, Volatility,
@@ -300,11 +303,7 @@ fn batches_to_paths(batches: &[RecordBatch]) -> Result<Vec<PartitionedFile>> {
300303
.iter()
301304
.flat_map(|batch| {
302305
let key_array = as_string_array(batch.column(0)).unwrap();
303-
let length_array = batch
304-
.column(1)
305-
.as_any()
306-
.downcast_ref::<UInt64Array>()
307-
.unwrap();
306+
let length_array = as_uint64_array(batch.column(1)).unwrap();
308307
let modified_array = batch
309308
.column(2)
310309
.as_any()

datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ use arrow::array::{Array, BooleanArray};
1919
use arrow::datatypes::{DataType, Schema};
2020
use arrow::error::{ArrowError, Result as ArrowResult};
2121
use arrow::record_batch::RecordBatch;
22+
use datafusion_common::cast::as_boolean_array;
2223
use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema};
2324
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
2425
use std::collections::BTreeSet;
@@ -134,17 +135,12 @@ impl ArrowPredicate for DatafusionArrowPredicate {
134135
.map(|v| v.into_array(batch.num_rows()))
135136
{
136137
Ok(array) => {
137-
if let Some(mask) = array.as_any().downcast_ref::<BooleanArray>() {
138-
let bool_arr = BooleanArray::from(mask.data().clone());
139-
let num_filtered = bool_arr.len() - bool_arr.true_count();
140-
self.rows_filtered.add(num_filtered);
141-
timer.stop();
142-
Ok(bool_arr)
143-
} else {
144-
Err(ArrowError::ComputeError(
145-
"Unexpected result of predicate evaluation, expected BooleanArray".to_owned(),
146-
))
147-
}
138+
let mask = as_boolean_array(&array)?;
139+
let bool_arr = BooleanArray::from(mask.data().clone());
140+
let num_filtered = bool_arr.len() - bool_arr.true_count();
141+
self.rows_filtered.add(num_filtered);
142+
timer.stop();
143+
Ok(bool_arr)
148144
}
149145
Err(e) => Err(ArrowError::ComputeError(format!(
150146
"Error evaluating filter predicate: {:?}",

datafusion/core/src/physical_plan/joins/hash_join.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ use std::{time::Instant, vec};
4040

4141
use futures::{ready, Stream, StreamExt, TryStreamExt};
4242

43-
use arrow::array::{as_boolean_array, new_null_array, Array};
43+
use arrow::array::{new_null_array, Array};
4444
use arrow::datatypes::{ArrowNativeType, DataType};
4545
use arrow::datatypes::{Schema, SchemaRef};
4646
use arrow::error::Result as ArrowResult;
@@ -52,7 +52,7 @@ use arrow::array::{
5252
UInt8Array,
5353
};
5454

55-
use datafusion_common::cast::as_string_array;
55+
use datafusion_common::cast::{as_boolean_array, as_string_array};
5656

5757
use hashbrown::raw::RawTable;
5858

@@ -1027,7 +1027,7 @@ fn apply_join_filter(
10271027
.expression()
10281028
.evaluate(&intermediate_batch)?
10291029
.into_array(intermediate_batch.num_rows());
1030-
let mask = as_boolean_array(&filter_result);
1030+
let mask = as_boolean_array(&filter_result)?;
10311031

10321032
let left_filtered = PrimitiveArray::<UInt64Type>::from(
10331033
compute::filter(&left_indices, mask)?.data().clone(),
@@ -1050,7 +1050,7 @@ fn apply_join_filter(
10501050
.expression()
10511051
.evaluate_selection(&intermediate_batch, &has_match)?
10521052
.into_array(intermediate_batch.num_rows());
1053-
let mask = as_boolean_array(&filter_result);
1053+
let mask = as_boolean_array(&filter_result)?;
10541054

10551055
let mut left_rebuilt = UInt64Builder::with_capacity(0);
10561056
let mut right_rebuilt = UInt32Builder::with_capacity(0);

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 9 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ use arrow::record_batch::RecordBatch;
7575

7676
use crate::physical_expr::down_cast_any_ref;
7777
use crate::{AnalysisContext, ExprBoundaries, PhysicalExpr};
78-
use datafusion_common::cast::as_decimal128_array;
78+
use datafusion_common::cast::{as_boolean_array, as_decimal128_array};
7979
use datafusion_common::ScalarValue;
8080
use datafusion_common::{DataFusionError, Result};
8181
use datafusion_expr::type_coercion::binary::binary_operator_data_type;
@@ -472,14 +472,8 @@ macro_rules! binary_array_op {
472472
/// Invoke a boolean kernel on a pair of arrays
473473
macro_rules! boolean_op {
474474
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
475-
let ll = $LEFT
476-
.as_any()
477-
.downcast_ref::<BooleanArray>()
478-
.expect("boolean_op failed to downcast array");
479-
let rr = $RIGHT
480-
.as_any()
481-
.downcast_ref::<BooleanArray>()
482-
.expect("boolean_op failed to downcast array");
475+
let ll = as_boolean_array($LEFT).expect("boolean_op failed to downcast array");
476+
let rr = as_boolean_array($RIGHT).expect("boolean_op failed to downcast array");
483477
Ok(Arc::new($OP(&ll, &rr)?))
484478
}};
485479
}
@@ -1003,7 +997,7 @@ impl BinaryExpr {
1003997
Operator::Modulo => binary_primitive_array_op!(left, right, modulus),
1004998
Operator::And => {
1005999
if left_data_type == &DataType::Boolean {
1006-
boolean_op!(left, right, and_kleene)
1000+
boolean_op!(&left, &right, and_kleene)
10071001
} else {
10081002
Err(DataFusionError::Internal(format!(
10091003
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
@@ -1015,7 +1009,7 @@ impl BinaryExpr {
10151009
}
10161010
Operator::Or => {
10171011
if left_data_type == &DataType::Boolean {
1018-
boolean_op!(left, right, or_kleene)
1012+
boolean_op!(&left, &right, or_kleene)
10191013
} else {
10201014
Err(DataFusionError::Internal(format!(
10211015
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
@@ -1110,10 +1104,8 @@ mod tests {
11101104
assert_eq!(result.len(), 5);
11111105

11121106
let expected = vec![false, false, true, true, true];
1113-
let result = result
1114-
.as_any()
1115-
.downcast_ref::<BooleanArray>()
1116-
.expect("failed to downcast to BooleanArray");
1107+
let result =
1108+
as_boolean_array(&result).expect("failed to downcast to BooleanArray");
11171109
for (i, &expected_item) in expected.iter().enumerate().take(5) {
11181110
assert_eq!(result.value(i), expected_item);
11191111
}
@@ -1156,10 +1148,8 @@ mod tests {
11561148
assert_eq!(result.len(), 5);
11571149

11581150
let expected = vec![true, true, false, true, false];
1159-
let result = result
1160-
.as_any()
1161-
.downcast_ref::<BooleanArray>()
1162-
.expect("failed to downcast to BooleanArray");
1151+
let result =
1152+
as_boolean_array(&result).expect("failed to downcast to BooleanArray");
11631153
for (i, &expected_item) in expected.iter().enumerate().take(5) {
11641154
assert_eq!(result.value(i), expected_item);
11651155
}

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ use arrow::compute::kernels::zip::zip;
2626
use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene};
2727
use arrow::datatypes::{DataType, Schema};
2828
use arrow::record_batch::RecordBatch;
29-
use datafusion_common::{DataFusionError, Result};
29+
use datafusion_common::{cast::as_boolean_array, DataFusionError, Result};
3030
use datafusion_expr::ColumnarValue;
3131

3232
use itertools::Itertools;
@@ -195,10 +195,7 @@ impl CaseExpr {
195195
_ => when_value,
196196
};
197197
let when_value = when_value.into_array(batch.num_rows());
198-
let when_value = when_value
199-
.as_ref()
200-
.as_any()
201-
.downcast_ref::<BooleanArray>()
198+
let when_value = as_boolean_array(&when_value)
202199
.expect("WHEN expression did not return a BooleanArray");
203200

204201
let then_value = self.when_then_expr[i]

0 commit comments

Comments
 (0)