feat: Implement bloom_filter_agg (#987)
* Add test that invokes bloom_filter_agg.

* QueryPlanSerde support for BloomFilterAgg.

* Add bloom_filter_agg based on the sample UDAF; the planner now instantiates it. Added spark_bit_array_tests.

* Partial work on Accumulator. Need to finish merge_batch and state.

* BloomFilterAgg state, merge_state, and evaluate. Need more tests.

* Matches Spark behavior. Need to clean up the code quite a bit, and do `cargo clippy`.

* Remove old comment.

* Clippy. Increase bloom filter size back to Spark's default.

* API cleanup.

* API cleanup.

* Add BloomFilterAgg benchmark to CometExecBenchmark

* Docs.

* API cleanup, fix merge_bits to update cardinality.

* Refactor merge_bits to update bit_count during the bit merging (see the sketch after this list).

* Remove benchmark results file.

* Docs.

* Add native side benchmarks.

* Adjust benchmark parameters to match Spark defaults.

* Address review feedback.

* Add assertion to merge_batch.

* Address some review feedback.

* Only generate native BloomFilterAgg if child has LongType.

* Add TODO with GitHub issue link.
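
The two merge_bits bullets above describe folding the cardinality update into the merge itself. A rough illustrative sketch of that idea (not code from this commit; it assumes the filter's bits are stored as u64 words): OR each incoming word into the local array and tally set bits in the same pass, so bit_count never goes stale.

    fn merge_bits(bits: &mut [u64], other: &[u64]) -> u64 {
        // Merge and recount in one pass: OR in the incoming word, then add its
        // population count to the running total of set bits.
        let mut bit_count = 0u64;
        for (word, incoming) in bits.iter_mut().zip(other.iter()) {
            *word |= *incoming;
            bit_count += word.count_ones() as u64;
        }
        bit_count
    }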
mbutrovich authored Oct 18, 2024
1 parent 8d097d5 commit e3ac6cf
Showing 12 changed files with 637 additions and 20 deletions.
4 changes: 4 additions & 0 deletions native/core/Cargo.toml
@@ -126,3 +126,7 @@ harness = false
[[bench]]
name = "aggregate"
harness = false

[[bench]]
name = "bloom_filter_agg"
harness = false
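
With this `[[bench]]` entry registered (`harness = false` so Criterion supplies its own main), the new benchmark can be run on its own with the standard Cargo command:

    cargo bench --bench bloom_filter_agg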
162 changes: 162 additions & 0 deletions native/core/benches/bloom_filter_agg.rs
@@ -0,0 +1,162 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::builder::Int64Builder;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::SchemaRef;
use comet::execution::datafusion::expressions::bloom_filter_agg::BloomFilterAgg;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::ScalarValue;
use datafusion_execution::TaskContext;
use datafusion_expr::AggregateUDF;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::{Column, Literal};
use futures::StreamExt;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Runtime;

fn criterion_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("bloom_filter_agg");
    let num_rows = 8192;
    let batch = create_record_batch(num_rows);
    let mut batches = Vec::new();
    for _ in 0..10 {
        batches.push(batch.clone());
    }
    let partitions = &[batches];
    let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
    // spark.sql.optimizer.runtime.bloomFilter.expectedNumItems
    let num_items_sv = ScalarValue::Int64(Some(1000000_i64));
    let num_items: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_items_sv));
    // spark.sql.optimizer.runtime.bloomFilter.numBits
    let num_bits_sv = ScalarValue::Int64(Some(8388608_i64));
    let num_bits: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(num_bits_sv));

    let rt = Runtime::new().unwrap();

    for agg_mode in [
        ("partial_agg", AggregateMode::Partial),
        ("single_agg", AggregateMode::Single),
    ] {
        group.bench_function(agg_mode.0, |b| {
            let comet_bloom_filter_agg =
                Arc::new(AggregateUDF::new_from_impl(BloomFilterAgg::new(
                    Arc::clone(&c0),
                    Arc::clone(&num_items),
                    Arc::clone(&num_bits),
                    "bloom_filter_agg",
                    DataType::Binary,
                )));
            b.to_async(&rt).iter(|| {
                black_box(agg_test(
                    partitions,
                    c0.clone(),
                    comet_bloom_filter_agg.clone(),
                    "bloom_filter_agg",
                    agg_mode.1,
                ))
            })
        });
    }

    group.finish();
}

async fn agg_test(
    partitions: &[Vec<RecordBatch>],
    c0: Arc<dyn PhysicalExpr>,
    aggregate_udf: Arc<AggregateUDF>,
    alias: &str,
    mode: AggregateMode,
) {
    let schema = &partitions[0][0].schema();
    let scan: Arc<dyn ExecutionPlan> =
        Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), None).unwrap());
    let aggregate = create_aggregate(scan, c0.clone(), schema, aggregate_udf, alias, mode);
    let mut stream = aggregate
        .execute(0, Arc::new(TaskContext::default()))
        .unwrap();
    while let Some(batch) = stream.next().await {
        let _batch = batch.unwrap();
    }
}

fn create_aggregate(
    scan: Arc<dyn ExecutionPlan>,
    c0: Arc<dyn PhysicalExpr>,
    schema: &SchemaRef,
    aggregate_udf: Arc<AggregateUDF>,
    alias: &str,
    mode: AggregateMode,
) -> Arc<AggregateExec> {
    let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c0.clone()])
        .schema(schema.clone())
        .alias(alias)
        .with_ignore_nulls(false)
        .with_distinct(false)
        .build()
        .unwrap();

    Arc::new(
        AggregateExec::try_new(
            mode,
            PhysicalGroupBy::new_single(vec![]),
            vec![aggr_expr],
            vec![None],
            scan,
            Arc::clone(schema),
        )
        .unwrap(),
    )
}

fn create_record_batch(num_rows: usize) -> RecordBatch {
    let mut int64_builder = Int64Builder::with_capacity(num_rows);
    for i in 0..num_rows {
        int64_builder.append_value(i as i64);
    }
    let int64_array = Arc::new(int64_builder.finish());

    let mut fields = vec![];
    let mut columns: Vec<ArrayRef> = vec![];

    // int64 column
    fields.push(Field::new("c0", DataType::Int64, false));
    columns.push(int64_array);

    let schema = Schema::new(fields);
    RecordBatch::try_new(Arc::new(schema), columns).unwrap()
}

fn config() -> Criterion {
    Criterion::default()
        .measurement_time(Duration::from_millis(500))
        .warm_up_time(Duration::from_millis(500))
}

criterion_group! {
    name = benches;
    config = config();
    targets = criterion_benchmark
}
criterion_main!(benches);
151 changes: 151 additions & 0 deletions native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
@@ -0,0 +1,151 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::Field;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use std::{any::Any, sync::Arc};

use crate::execution::datafusion::util::spark_bloom_filter;
use crate::execution::datafusion::util::spark_bloom_filter::SparkBloomFilter;
use arrow::array::ArrayRef;
use arrow_array::BinaryArray;
use datafusion::error::Result;
use datafusion::physical_expr::PhysicalExpr;
use datafusion_common::{downcast_value, DataFusionError, ScalarValue};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDFImpl, Signature,
};
use datafusion_physical_expr::expressions::Literal;

#[derive(Debug, Clone)]
pub struct BloomFilterAgg {
    name: String,
    signature: Signature,
    expr: Arc<dyn PhysicalExpr>,
    num_items: i32,
    num_bits: i32,
}

#[inline]
fn extract_i32_from_literal(expr: Arc<dyn PhysicalExpr>) -> i32 {
    match expr.as_any().downcast_ref::<Literal>().unwrap().value() {
        ScalarValue::Int64(scalar_value) => scalar_value.unwrap() as i32,
        _ => {
            unreachable!()
        }
    }
}

impl BloomFilterAgg {
    pub fn new(
        expr: Arc<dyn PhysicalExpr>,
        num_items: Arc<dyn PhysicalExpr>,
        num_bits: Arc<dyn PhysicalExpr>,
        name: impl Into<String>,
        data_type: DataType,
    ) -> Self {
        assert!(matches!(data_type, DataType::Binary));
        Self {
            name: name.into(),
            signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
            expr,
            num_items: extract_i32_from_literal(num_items),
            num_bits: extract_i32_from_literal(num_bits),
        }
    }
}

impl AggregateUDFImpl for BloomFilterAgg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "bloom_filter_agg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Binary)
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(SparkBloomFilter::from((
            spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
            self.num_bits,
        ))))
    }

    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
        Ok(vec![Field::new("bits", DataType::Binary, false)])
    }

    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        false
    }
}

impl Accumulator for SparkBloomFilter {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
        let arr = &values[0];
        (0..arr.len()).try_for_each(|index| {
            let v = ScalarValue::try_from_array(arr, index)?;

            if let ScalarValue::Int64(Some(value)) = v {
                self.put_long(value);
            } else {
                unreachable!()
            }
            Ok(())
        })
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Binary(Some(self.spark_serialization())))
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        // There might be a more efficient way to do this by transmuting since calling state() on an
        // Accumulator is considered destructive.
        let state_sv = ScalarValue::Binary(Some(self.state_as_bytes()));
        Ok(vec![state_sv])
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        assert_eq!(
            states.len(),
            1,
            "Expect one element in 'states' but found {}",
            states.len()
        );
        assert_eq!(states[0].len(), 1);
        let state_sv = downcast_value!(states[0], BinaryArray);
        self.merge_filter(state_sv.value_data());
        Ok(())
    }
}
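
As a reading aid, here is an illustrative sketch (not part of the commit) of how these Accumulator methods fit together across a two-stage aggregation, using only the SparkBloomFilter APIs that appear above; the constants mirror the Spark defaults used in the benchmark:

    // Partial-mode accumulator: update_batch() effectively calls put_long() per row.
    let num_bits = 8388608;
    let num_hashes = spark_bloom_filter::optimal_num_hash_functions(1000000, num_bits);
    let mut partial = SparkBloomFilter::from((num_hashes, num_bits));
    partial.put_long(42);

    // state() ships the raw filter bits to the final aggregation as a Binary column.
    let state_bytes = partial.state_as_bytes();

    // Final-mode accumulator: merge_batch() ORs each partial state into its filter.
    let mut merged = SparkBloomFilter::from((num_hashes, num_bits));
    merged.merge_filter(&state_bytes);

    // evaluate() returns the Spark-compatible serialized filter.
    let _result = ScalarValue::Binary(Some(merged.spark_serialization()));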
native/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -72,7 +72,7 @@ fn evaluate_bloom_filter(
    let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
    match bloom_filter_bytes {
        ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
-            Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
+            Ok(v.map(|v| SparkBloomFilter::from(v.as_bytes())))
        }
        _ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
    }
1 change: 1 addition & 0 deletions native/core/src/execution/datafusion/expressions/mod.rs
@@ -25,6 +25,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
use crate::errors::CometError;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_agg;
pub mod bloom_filter_might_contain;
pub mod comet_scalar_funcs;
pub mod correlation;
17 changes: 17 additions & 0 deletions native/core/src/execution/datafusion/planner.rs
@@ -28,6 +28,7 @@ use crate::{
        avg::Avg,
        avg_decimal::AvgDecimal,
        bitwise_not::BitwiseNotExpr,
        bloom_filter_agg::BloomFilterAgg,
        bloom_filter_might_contain::BloomFilterMightContain,
        checkoverflow::CheckOverflow,
        correlation::Correlation,
@@ -1620,6 +1621,22 @@
                ));
                Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func)
            }
            AggExprStruct::BloomFilterAgg(expr) => {
                let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
                let num_items =
                    self.create_expr(expr.num_items.as_ref().unwrap(), Arc::clone(&schema))?;
                let num_bits =
                    self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
                    Arc::clone(&child),
                    Arc::clone(&num_items),
                    Arc::clone(&num_bits),
                    "bloom_filter_agg",
                    datatype,
                ));
                Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
            }
        }
    }
