Skip to content

Commit

Permalink
Fix accumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 13, 2023
1 parent 2f04527 commit a27881f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 102 deletions.
68 changes: 18 additions & 50 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1933,101 +1933,69 @@ CREATE TABLE test_table (c1 INT, c2 INT, c3 INT)

# Inserting data
statement ok
INSERT INTO test_table VALUES
(1, 10, 50),
(1, 20, 60),
(2, 10, 70),
(2, 20, 80),
(3, 10, NULL)
INSERT INTO test_table VALUES (1, 10, 50), (1, 20, 60), (2, 10, 70), (2, 20, 80), (3, 10, NULL)

# query_group_by_with_filter
query III rowsort
SELECT
c1,
SUM(c2) FILTER (WHERE c2 >= 20),
SUM(c2) FILTER (WHERE c2 < 1) -- no rows pass filter, so the output should be NULL
FROM test_table GROUP BY c1
query II rowsort
SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result FROM test_table GROUP BY c1
----
1 20 NULL
2 20 NULL
3 NULL NULL
1 20
2 20
3 NULL

# query_group_by_avg_with_filter
query IRR rowsort
SELECT
c1,
AVG(c2) FILTER (WHERE c2 >= 20),
AVG(c2) FILTER (WHERE c2 < 1) -- no rows pass filter, so output should be null
FROM test_table GROUP BY c1
query IR rowsort
SELECT c1, AVG(c2) FILTER (WHERE c2 >= 20) AS avg_c2 FROM test_table GROUP BY c1
----
1 20 NULL
2 20 NULL
3 NULL NULL
1 20
2 20
3 NULL

# query_group_by_with_multiple_filters
query IIR rowsort
SELECT
c1,
SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2,
AVG(c3) FILTER (WHERE c3 <= 70) AS avg_c3
FROM test_table GROUP BY c1
SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2, AVG(c3) FILTER (WHERE c3 <= 70) AS avg_c3 FROM test_table GROUP BY c1
----
1 20 55
2 20 70
3 NULL NULL

# query_group_by_distinct_with_filter
query II rowsort
SELECT
c1,
COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS distinct_c2_count
FROM test_table GROUP BY c1
SELECT c1, COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS distinct_c2_count FROM test_table GROUP BY c1
----
1 1
2 1
3 0

# query_without_group_by_with_filter
query I rowsort
SELECT
SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2
FROM test_table
SELECT SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2 FROM test_table
----
40

# count_without_group_by_with_filter
query I rowsort
SELECT
COUNT(c2) FILTER (WHERE c2 >= 20) AS count_c2
FROM test_table
SELECT COUNT(c2) FILTER (WHERE c2 >= 20) AS count_c2 FROM test_table
----
2

# query_with_and_without_filter
query III rowsort
SELECT
c1,
SUM(c2) FILTER (WHERE c2 >= 20) as result,
SUM(c2) as result_no_filter
FROM test_table GROUP BY c1;
SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result, SUM(c2) as result_no_filter FROM test_table GROUP BY c1;
----
1 20 30
2 20 30
3 NULL 10

# query_filter_on_different_column_than_aggregate
query I rowsort
select
sum(c1) FILTER (WHERE c2 < 30)
FROM test_table;
select sum(c1) FILTER (WHERE c2 < 30) from test_table;
----
9

# query_test_empty_filter
query I rowsort
SELECT
SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2
FROM test_table;
SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test_table;
----
NULL

Expand Down
83 changes: 32 additions & 51 deletions datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,16 @@ use arrow::compute::{bit_and, bit_or, bit_xor};
use datafusion_row::accessor::RowAccessor;

/// Creates a [`PrimitiveGroupsAccumulator`] with the specified
/// [`ArrowPrimitiveType`] which applies `$FN` to each element
/// [`ArrowPrimitiveType`] that initailizes each accumulator to $START
/// and applies `$FN` to each element
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! instantiate_primitive_accumulator {
($SELF:expr, $PRIMTYPE:ident, $FN:expr) => {{
Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
&$SELF.data_type,
$FN,
)))
macro_rules! instantiate_accumulator {
($SELF:expr, $START:expr, $PRIMTYPE:ident, $FN:expr) => {{
Ok(Box::new(
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$SELF.data_type, $FN)
.with_starting_value($START),
))
}};
}

Expand Down Expand Up @@ -279,35 +280,31 @@ impl AggregateExpr for BitAnd {
use std::ops::BitAndAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int8Type, |x, y| x.bitand_assign(y))
}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int16Type, |x, y| x.bitand_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int32Type, |x, y| x.bitand_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int64Type, |x, y| x.bitand_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
instantiate_accumulator!(self, u8::MAX, UInt8Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
instantiate_accumulator!(self, u16::MAX, UInt16Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
instantiate_accumulator!(self, u32::MAX, UInt32Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
instantiate_accumulator!(self, u64::MAX, UInt64Type, |x, y| x
.bitand_assign(y))
}

Expand Down Expand Up @@ -517,36 +514,28 @@ impl AggregateExpr for BitOr {
use std::ops::BitOrAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitor_assign(y))
}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitor_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitor_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitor_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitor_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitor_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitor_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitor_assign(y))
}

_ => Err(DataFusionError::NotImplemented(format!(
Expand Down Expand Up @@ -756,36 +745,28 @@ impl AggregateExpr for BitXor {
use std::ops::BitXorAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitxor_assign(y))
}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitxor_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitxor_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitxor_assign(y))
}

_ => Err(DataFusionError::NotImplemented(format!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ where
/// The output type (needed for Decimal precision and scale)
data_type: DataType,

/// The starting value for new groups
starting_value: T::Native,

/// Track nulls in the input / filters
null_state: NullState,

Expand All @@ -64,9 +67,16 @@ where
values: vec![],
data_type: data_type.clone(),
null_state: NullState::new(),
starting_value: T::default_value(),
prim_fn,
}
}

/// Set the starting values for new groups
pub fn with_starting_value(mut self, starting_value: T::Native) -> Self {
self.starting_value = starting_value;
self
}
}

impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
Expand All @@ -85,7 +95,7 @@ where
let values = values[0].as_primitive::<T>();

// update values
self.values.resize(total_num_groups, T::default_value());
self.values.resize(total_num_groups, self.starting_value);

// NullState dispatches / handles tracking nulls and groups that saw no values
self.null_state.accumulate(
Expand Down

0 comments on commit a27881f

Please sign in to comment.