Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions arrow-select/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
let mut filter_builder = FilterBuilder::new(predicate);

if multiple_arrays(values.data_type()) {
if FilterBuilder::is_optimize_beneficial(values.data_type()) {
// Only optimize if filtering more than one array
// Otherwise, the overhead of optimization can be more than the benefit
filter_builder = filter_builder.optimize();
Expand All @@ -166,16 +166,6 @@ pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef,
filter_array(values, &predicate)
}

/// Reports whether `data_type` is a nested type that this module treats as
/// being made up of more than one child array.
///
/// Returns `true` for:
/// * a struct with two or more fields,
/// * a struct with exactly one field whose own type satisfies this predicate
///   (checked recursively),
/// * a sparse union with at least one field.
///
/// Every other type, including an empty struct, yields `false`.
fn multiple_arrays(data_type: &DataType) -> bool {
    match data_type {
        DataType::Struct(fields) => match fields.len() {
            // No children at all.
            0 => false,
            // A lone child only counts if it is itself a qualifying nested type.
            1 => multiple_arrays(fields[0].data_type()),
            // Two or more children.
            _ => true,
        },
        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
        _ => false,
    }
}

/// Returns a filtered [RecordBatch] where the corresponding elements of
/// `predicate` are true.
///
Expand All @@ -193,7 +183,10 @@ pub fn filter_record_batch(
let mut filter_builder = FilterBuilder::new(predicate);
let num_cols = record_batch.num_columns();
if num_cols > 1
|| (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type()))
|| (num_cols > 0
&& FilterBuilder::is_optimize_beneficial(
record_batch.schema_ref().field(0).data_type(),
))
{
// Only optimize if filtering more than one column or if the column contains multiple internal arrays
// Otherwise, the overhead of optimization can be more than the benefit
Expand Down Expand Up @@ -230,7 +223,7 @@ impl FilterBuilder {
}
}

/// Compute an optimised representation of the provided `filter` mask that can be
/// Compute an optimized representation of the provided `filter` mask that can be
/// applied to an array more quickly.
///
/// Note: There is limited benefit to calling this to then filter a single array
Expand All @@ -250,6 +243,20 @@ impl FilterBuilder {
self
}

/// Returns `true` when running [FilterBuilder::optimize] is expected to pay off
/// for `data_type` even though only a single array is being filtered.
///
/// The answer is `true` for:
/// * a struct with two or more fields,
/// * a struct with exactly one field whose type itself satisfies this check
///   (evaluated recursively),
/// * a sparse union with at least one field.
///
/// All other types, including an empty struct, return `false`.
pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
    match data_type {
        DataType::Struct(fields) => match fields.len() {
            // Nothing nested to filter.
            0 => false,
            // Defer to the only child: it may itself be a qualifying nested type.
            1 => FilterBuilder::is_optimize_beneficial(fields[0].data_type()),
            // Several children share the same predicate.
            _ => true,
        },
        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
        _ => false,
    }
}

/// Construct the final `FilterPredicate`
pub fn build(self) -> FilterPredicate {
FilterPredicate {
Expand Down
Loading