feat: Implement bloom_filter_agg #987

Merged · 28 commits · Oct 18, 2024

Changes from 11 commits

Commits
e662814  Add test that invokes bloom_filter_agg. (mbutrovich, Sep 25, 2024)
20f6e67  QueryPlanSerde support for BloomFilterAgg. (mbutrovich, Sep 25, 2024)
1ec31a2  Add bloom_filter_agg based on sample UDAF. planner instantiates it no… (mbutrovich, Sep 27, 2024)
3965dc4  Partial work on Accumulator. Need to finish merge_batch and state. (mbutrovich, Sep 27, 2024)
62e656c  BloomFilterAgg state, merge_state, and evaluate. Need more tests. (mbutrovich, Sep 30, 2024)
33ef47d  Matches Spark behavior. Need to clean up the code quite a bit, and do… (mbutrovich, Sep 30, 2024)
2040c76  Merge branch 'apache:main' into bloom_field_agg (mbutrovich, Sep 30, 2024)
599a8f9  Remove old comment. (mbutrovich, Sep 30, 2024)
a2a8cf3  Clippy. Increase bloom filter size back to Spark's default. (mbutrovich, Sep 30, 2024)
22aedd9  API cleanup. (mbutrovich, Sep 30, 2024)
bf22902  API cleanup. (mbutrovich, Oct 1, 2024)
4b7000c  Merge branch 'apache:main' into bloom_field_agg (mbutrovich, Oct 2, 2024)
88adc75  Add BloomFilterAgg benchmark to CometExecBenchmark (mbutrovich, Oct 2, 2024)
a21e0e3  Docs. (mbutrovich, Oct 2, 2024)
5c5d0f9  API cleanup, fix merge_bits to update cardinality. (mbutrovich, Oct 2, 2024)
cd107e3  Refactor merge_bits to update bit_count with the bit merging. (mbutrovich, Oct 2, 2024)
4f06098  Remove benchmark results file. (mbutrovich, Oct 2, 2024)
79f6468  Docs. (mbutrovich, Oct 2, 2024)
57fe742  Add native side benchmarks. (mbutrovich, Oct 2, 2024)
ec64e4c  Adjust benchmark parameters to match Spark defaults. (mbutrovich, Oct 2, 2024)
7a81f35  Address review feedback. (mbutrovich, Oct 2, 2024)
013513e  Merge branch 'apache:main' into bloom_field_agg (mbutrovich, Oct 3, 2024)
3347923  Add assertion to merge_batch. (mbutrovich, Oct 4, 2024)
5c82f24  Merge branch 'apache:main' into bloom_field_agg (mbutrovich, Oct 9, 2024)
c39ff1d  Merge branch 'apache:main' into bloom_field_agg (mbutrovich, Oct 13, 2024)
1ed99e3  Address some review feedback. (mbutrovich, Oct 17, 2024)
d41a9d2  Only generate native BloomFilterAgg if child has LongType. (mbutrovich, Oct 18, 2024)
6d13890  Add TODO with GitHub issue link. (mbutrovich, Oct 18, 2024)
143 changes: 143 additions & 0 deletions native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::Field;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use std::{any::Any, sync::Arc};

use crate::execution::datafusion::util::spark_bloom_filter;
use crate::execution::datafusion::util::spark_bloom_filter::SparkBloomFilter;
use arrow::array::ArrayRef;
use arrow_array::BinaryArray;
use datafusion::error::Result;
use datafusion::physical_expr::PhysicalExpr;
use datafusion_common::{downcast_value, DataFusionError, ScalarValue};
use datafusion_expr::{
    function::{AccumulatorArgs, StateFieldsArgs},
    Accumulator, AggregateUDFImpl, Signature,
};
use datafusion_physical_expr::expressions::Literal;

#[derive(Debug, Clone)]
pub struct BloomFilterAgg {
    name: String,
    signature: Signature,
    expr: Arc<dyn PhysicalExpr>,
    num_items: i32,
    num_bits: i32,
}

fn extract_i32_from_literal(expr: Arc<dyn PhysicalExpr>) -> i32 {
    match expr.as_any().downcast_ref::<Literal>().unwrap().value() {
        ScalarValue::Int64(scalar_value) => scalar_value.unwrap() as i32,
        _ => {
            unreachable!()
        }
    }
}
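// Usage sketch (hypothetical values, not part of this PR): the planner hands
// num_items and num_bits over as Int64 literal expressions, so the helper
// just downcasts the Literal and narrows the value to i32, e.g.:
//
//     let num_items: Arc<dyn PhysicalExpr> =
//         Arc::new(Literal::new(ScalarValue::Int64(Some(1_000_000))));
//     assert_eq!(extract_i32_from_literal(num_items), 1_000_000);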

impl BloomFilterAgg {
    pub fn new(
        expr: Arc<dyn PhysicalExpr>,
        num_items: Arc<dyn PhysicalExpr>,
        num_bits: Arc<dyn PhysicalExpr>,
        name: impl Into<String>,
        data_type: DataType,
    ) -> Self {
        assert!(matches!(data_type, DataType::Binary));
        Self {
            name: name.into(),
            signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
            expr,
            num_items: extract_i32_from_literal(num_items),
            num_bits: extract_i32_from_literal(num_bits),
        }
    }
}

impl AggregateUDFImpl for BloomFilterAgg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "bloom_filter_agg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Binary)
    }

    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(SparkBloomFilter::from((
            spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
            self.num_bits,
        ))))
    }

    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
        Ok(vec![Field::new("bits", DataType::Binary, false)])
    }

    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        false
    }
}
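// For context: a sketch of the standard sizing rule that
// spark_bloom_filter::optimal_num_hash_functions is assumed to implement,
// k = max(1, round(num_bits / num_items * ln 2)). This mirrors the formula
// Spark's BloomFilter uses and is shown for illustration only; it is not
// part of this PR.
fn optimal_num_hash_functions_sketch(num_items: i32, num_bits: i32) -> i32 {
    let k = (num_bits as f64 / num_items as f64 * std::f64::consts::LN_2).round();
    k.max(1.0) as i32
}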

impl Accumulator for SparkBloomFilter {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
        let arr = &values[0];
        (0..arr.len()).try_for_each(|index| {
            let v = ScalarValue::try_from_array(arr, index)?;

            if let ScalarValue::Int64(Some(value)) = v {
Review thread (resolved):

[Member] It only supports Int64? Spark BloomFilterAggregate supports Byte, Short, Int, Long, and String. If Comet's BloomFilterAggregate only supports Int64 for now, we need to fall back to Spark for the other cases in QueryPlanSerde.

[Author (mbutrovich)] I think I was going off of their docs, which say it only supports Long. In their implementation, however, it looks like they can cast the fixed-width types directly to Long:
https://github.com/apache/spark/blob/b078c0d6e2adf7eb0ee7d4742a6c52864440226e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala#L238
For strings, their bloom filter implementation has a putBinary method that we don't currently support. The casts should be easy; I'll look at what putBinary on our bloom filter implementation will take.

[Author (mbutrovich), Oct 17, 2024] Ah, I see what happened: 3.4 only supports Long, which was the Spark source I was working off of. 3.5 added support for other types.

[Author (mbutrovich)] I modified it to only generate a native BloomFilterAgg if the child has LongType. I'll open an issue to support more types in the future. (See the widening-cast sketch after this impl block.)

                self.put_long(value);
            } else {
                unreachable!()
            }
            Ok(())
        })
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Binary(Some(self.spark_serialization())))
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        // TODO(Matt): There might be a more efficient way to do this by transmuting, since
        // calling state() on an Accumulator is considered destructive.
        let state_sv = ScalarValue::Binary(Some(self.state_as_bytes()));
Review comment (Contributor):

One way to avoid the copy, which may be too ugly, would be to store the bloom filter data as an Option. So instead of

pub struct SparkBloomFilter {
    bits: SparkBitArray,
    num_hash_functions: u32,
}

something like

pub struct SparkBloomFilter {
    bits: Option<SparkBitArray>,
    num_hash_functions: u32,
}

You could then basically use Option::take to take the value and leave a None in its place:

let Some(bits) = self.bits.take() else {
    return Err(/* invalid state */);
};

// do whatever you want now that you have the owned `bits`
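A completed version of that suggestion might look like this (a sketch only: it assumes the bits field becomes Option<SparkBitArray> and that SparkBitArray gains a consuming into_bytes helper, neither of which this PR adds):

fn state(&mut self) -> Result<Vec<ScalarValue>> {
    // Hand off the owned bit array, leaving None behind; a second call to
    // state() then surfaces an internal error instead of silently copying.
    let Some(bits) = self.bits.take() else {
        return Err(DataFusionError::Internal(
            "SparkBloomFilter state consumed more than once".to_string(),
        ));
    };
    Ok(vec![ScalarValue::Binary(Some(bits.into_bytes()))])
}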

        Ok(vec![state_sv])
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        let state_sv = downcast_value!(states[0], BinaryArray);
        self.merge_filter(state_sv.value_data());
        Ok(())
    }
}
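Following up on the resolved type-support thread above, a hedged sketch of how update_batch could widen the smaller fixed-width integer types to i64 before insertion, as Spark 3.5 does on the JVM side. This is illustrative only; the merged PR instead falls back to Spark unless the child is LongType.

// Hypothetical helper, not part of this PR: widen Byte/Short/Int to Long
// before inserting, mirroring the casts Spark 3.5 performs in
// BloomFilterAggregate.
fn put_integral_scalar(filter: &mut SparkBloomFilter, v: &ScalarValue) -> Result<()> {
    match v {
        ScalarValue::Int8(Some(x)) => {
            filter.put_long(*x as i64);
        }
        ScalarValue::Int16(Some(x)) => {
            filter.put_long(*x as i64);
        }
        ScalarValue::Int32(Some(x)) => {
            filter.put_long(*x as i64);
        }
        ScalarValue::Int64(Some(x)) => {
            filter.put_long(*x);
        }
        _ => {
            return Err(DataFusionError::Internal(
                "unsupported bloom_filter_agg input type".to_string(),
            ))
        }
    }
    Ok(())
}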
@@ -72,7 +72,7 @@ fn evaluate_bloom_filter(
    let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
    match bloom_filter_bytes {
        ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
-            Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
+            Ok(v.map(|v| SparkBloomFilter::from(v.as_bytes())))
        }
        _ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
    }
1 change: 1 addition & 0 deletions native/core/src/execution/datafusion/expressions/mod.rs
@@ -25,6 +25,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
use crate::errors::CometError;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_agg;
pub mod bloom_filter_might_contain;
pub mod comet_scalar_funcs;
pub mod correlation;
17 changes: 17 additions & 0 deletions native/core/src/execution/datafusion/planner.rs
@@ -28,6 +28,7 @@ use crate::{
    avg::Avg,
    avg_decimal::AvgDecimal,
    bitwise_not::BitwiseNotExpr,
    bloom_filter_agg::BloomFilterAgg,
    bloom_filter_might_contain::BloomFilterMightContain,
    checkoverflow::CheckOverflow,
    correlation::Correlation,
@@ -1611,6 +1612,22 @@ impl PhysicalPlanner {
                ));
                Self::create_aggr_func_expr("correlation", schema, vec![child1, child2], func)
            }
            AggExprStruct::BloomFilterAgg(expr) => {
                let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
                let num_items =
                    self.create_expr(expr.num_items.as_ref().unwrap(), Arc::clone(&schema))?;
                let num_bits =
                    self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
                    Arc::clone(&child),
                    Arc::clone(&num_items),
                    Arc::clone(&num_bits),
                    "bloom_filter_agg",
                    datatype,
                ));
                Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
            }
        }
    }

103 changes: 102 additions & 1 deletion native/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use arrow_buffer::ToByteSlice;
use std::iter::zip;

/// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
/// used in the BloomFilter implementation. Some methods are not implemented as they are not
/// required for the current use case.
@@ -55,12 +58,48 @@ impl SparkBitArray {
    }

    pub fn bit_size(&self) -> u64 {
-        self.data.len() as u64 * 64
+        self.word_size() as u64 * 64
    }

    pub fn byte_size(&self) -> usize {
        self.word_size() * 8
    }

    pub fn word_size(&self) -> usize {
        self.data.len()
    }

    pub fn cardinality(&self) -> usize {
        self.bit_count
    }

    pub fn to_bytes(&self) -> Vec<u8> {
        Vec::from(self.data.to_byte_slice())
    }

    pub fn to_bytes_not_vec(&self) -> &[u8] {
        self.data.to_byte_slice()
    }

    pub fn data(&self) -> Vec<u64> {
        self.data.clone()
    }

    pub fn merge_bits(&mut self, other: &[u8]) {
        assert_eq!(self.byte_size(), other.len());
        for i in zip(
            self.data.iter_mut(),
            other
                .chunks(8)
                .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
        ) {
            *i.0 |= i.1;
        }
    }
}

pub fn num_words(num_bits: i32) -> i32 {
    (num_bits as f64 / 64.0).ceil() as i32
}
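// For comparison, an equivalent pure-integer formulation of num_words
// (a sketch, not part of this PR; assumes num_bits + 63 does not overflow
// i32). Sample values: num_words(1) == 1, num_words(64) == 1,
// num_words(65) == 2.
fn num_words_int(num_bits: i32) -> i32 {
    (num_bits + 63) / 64
}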

#[cfg(test)]
@@ -128,4 +167,66 @@ mod test {
        // check cardinality
        assert_eq!(array.cardinality(), 6);
    }

    #[test]
    fn test_spark_bit_with_empty_buffer() {
        let buf = vec![0u64; 4];
        let array = SparkBitArray::new(buf);

        assert_eq!(array.bit_size(), 256);
        assert_eq!(array.cardinality(), 0);

        for n in 0..256 {
            assert!(!array.get(n));
        }
    }

    #[test]
    fn test_spark_bit_with_full_buffer() {
        let buf = vec![u64::MAX; 4];
        let array = SparkBitArray::new(buf);

        assert_eq!(array.bit_size(), 256);
        assert_eq!(array.cardinality(), 256);

        for n in 0..256 {
            assert!(array.get(n));
        }
    }

    #[test]
    fn test_spark_bit_merge() {
        let buf1 = vec![0u64; 4];
        let mut array1 = SparkBitArray::new(buf1);
        let buf2 = vec![0u64; 4];
        let mut array2 = SparkBitArray::new(buf2);

        let primes = [
            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79,
            83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167,
            173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251,
        ];
        let fibs = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233];

        for n in fibs {
            array1.set(n);
        }

        for n in primes {
            array2.set(n);
        }

        assert_eq!(array1.cardinality(), fibs.len());
        assert_eq!(array2.cardinality(), primes.len());

        array1.merge_bits(array2.to_bytes_not_vec());

        for n in fibs {
            assert!(array1.get(n));
        }

        for n in primes {
            assert!(array1.get(n));
        }
    }
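
    // Hypothetical extra test, not in this PR: a minimal round trip of the
    // new byte-serialization path, OR-merging one array's bytes into another.
    // 0b0011 | 0b0101 == 0b0111, so the merged word has three bits set.
    #[test]
    fn test_merge_bits_round_trip_sketch() {
        let mut a = SparkBitArray::new(vec![0b0011u64]);
        let b = SparkBitArray::new(vec![0b0101u64]);
        a.merge_bits(&b.to_bytes());
        assert_eq!(a.data(), vec![0b0111u64]);
    }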
}