Skip to content

Commit

Permalink
perf: Optimize decimal precision check in decimal aggregates (sum and…
Browse files Browse the repository at this point in the history
… avg) (#952)

* agg bench

* fix

* fix

* refactor

* avg

* optimized decimal aggregates with more efficient version of validate_decimal_precision

* simplify function to remove branch

* address feedback

* format

* Revert a change

* add rust unit test

* format

* code cleanup

* fix

* fix

* fix

* fmt

* update bench

* clippy
  • Loading branch information
andygrove authored Sep 24, 2024
1 parent 5b3f7bc commit 50517f6
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 84 deletions.
9 changes: 3 additions & 6 deletions native/core/benches/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("avg_decimal_comet", |b| {
let comet_avg_decimal = Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new(
Arc::clone(&c1),
"avg",
DataType::Decimal128(38, 10),
DataType::Decimal128(38, 10),
)));
Expand Down Expand Up @@ -96,11 +95,9 @@ fn criterion_benchmark(c: &mut Criterion) {
});

group.bench_function("sum_decimal_comet", |b| {
let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new(
"sum",
Arc::clone(&c1),
DataType::Decimal128(38, 10),
)));
let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
SumDecimal::try_new(Arc::clone(&c1), DataType::Decimal128(38, 10)).unwrap(),
));
b.to_async(&rt).iter(|| {
black_box(agg_test(
partitions,
Expand Down
27 changes: 9 additions & 18 deletions native/core/src/execution/datafusion/expressions/avg_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr};
use std::{any::Any, sync::Arc};

use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
use arrow_array::ArrowNativeTypeOp;
use arrow_data::decimal::{
validate_decimal_precision, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use datafusion::logical_expr::Volatility::Immutable;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
Expand All @@ -43,7 +42,6 @@ use DataType::*;
/// AVG aggregate expression
#[derive(Debug, Clone)]
pub struct AvgDecimal {
name: String,
signature: Signature,
expr: Arc<dyn PhysicalExpr>,
sum_data_type: DataType,
Expand All @@ -52,14 +50,8 @@ pub struct AvgDecimal {

impl AvgDecimal {
/// Create a new AVG aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
result_type: DataType,
sum_type: DataType,
) -> Self {
pub fn new(expr: Arc<dyn PhysicalExpr>, result_type: DataType, sum_type: DataType) -> Self {
Self {
name: name.into(),
signature: Signature::user_defined(Immutable),
expr,
result_data_type: result_type,
Expand Down Expand Up @@ -95,20 +87,20 @@ impl AggregateUDFImpl for AvgDecimal {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![
Field::new(
format_state_name(&self.name, "sum"),
format_state_name(self.name(), "sum"),
self.sum_data_type.clone(),
true,
),
Field::new(
format_state_name(&self.name, "count"),
format_state_name(self.name(), "count"),
DataType::Int64,
true,
),
])
}

fn name(&self) -> &str {
&self.name
"avg"
}

fn reverse_expr(&self) -> ReversedUDAF {
Expand Down Expand Up @@ -169,8 +161,7 @@ impl PartialEq<dyn Any> for AvgDecimal {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.sum_data_type == x.sum_data_type
self.sum_data_type == x.sum_data_type
&& self.result_data_type == x.result_data_type
&& self.expr.eq(&x.expr)
})
Expand Down Expand Up @@ -212,7 +203,7 @@ impl AvgDecimalAccumulator {
None => (v, false),
};

if is_overflow || validate_decimal_precision(new_sum, self.sum_precision).is_err() {
if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
// Overflow: set buffer accumulator to null
self.is_not_null = false;
return;
Expand Down Expand Up @@ -380,7 +371,7 @@ impl AvgDecimalGroupsAccumulator {
let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value);
self.counts[group_index] += 1;

if is_overflow || validate_decimal_precision(new_sum, self.sum_precision).is_err() {
if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
// Overflow: set buffer accumulator to null
self.is_not_null.set_bit(group_index, false);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use arrow::{
datatypes::{Decimal128Type, DecimalType},
record_batch::RecordBatch,
};
use arrow_schema::{DataType, Schema};
use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use arrow_schema::{DataType, Schema, DECIMAL128_MAX_PRECISION};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{DataFusionError, ScalarValue};
Expand Down Expand Up @@ -171,3 +172,15 @@ impl PhysicalExpr for CheckOverflow {
self.hash(&mut s);
}
}

/// Returns `true` when `value` fits in a Decimal128 with the given `precision`.
///
/// Adapted from arrow-rs `validate_decimal_precision`, but returns a `bool`
/// instead of an `Err` to avoid the cost of formatting the error strings, and
/// is optimized to remove a memcpy that exists in the original function.
///
/// The bounds are looked up in `MIN_DECIMAL_FOR_EACH_PRECISION` and
/// `MAX_DECIMAL_FOR_EACH_PRECISION` at index `precision - 1`. Any `precision`
/// outside `1..=DECIMAL128_MAX_PRECISION` is reported as invalid (`false`);
/// in particular `precision == 0` is rejected before indexing, because
/// `precision as usize - 1` would otherwise underflow and panic.
///
/// We can remove this code once we upgrade to a version of arrow-rs that
/// includes https://github.com/apache/arrow-rs/pull/6419
#[inline]
pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
    // The range check short-circuits first, so the `- 1` below can never
    // underflow and the index is always within the lookup tables.
    (1..=DECIMAL128_MAX_PRECISION).contains(&precision)
        && value >= MIN_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
        && value <= MAX_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
}
151 changes: 125 additions & 26 deletions native/core/src/execution/datafusion/expressions/sum_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
use crate::unlikely;
use arrow::{
array::BooleanBufferBuilder,
Expand All @@ -23,11 +24,10 @@ use arrow::{
use arrow_array::{
cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
};
use arrow_data::decimal::validate_decimal_precision;
use arrow_schema::{DataType, Field};
use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator};
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{Result as DFResult, ScalarValue};
use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature};
Expand All @@ -36,37 +36,37 @@ use std::{any::Any, ops::BitAnd, sync::Arc};

#[derive(Debug)]
pub struct SumDecimal {
name: String,
/// Aggregate function signature
signature: Signature,
/// The expression that provides the input decimal values to be summed
expr: Arc<dyn PhysicalExpr>,

/// The data type of the SUM result
/// The data type of the SUM result. This will always be a decimal type
/// with the same precision and scale as specified in this struct
result_type: DataType,

/// Decimal precision and scale
/// Decimal precision
precision: u8,
/// Decimal scale
scale: i8,

/// Whether the result is nullable
nullable: bool,
}

impl SumDecimal {
pub fn new(name: impl Into<String>, expr: Arc<dyn PhysicalExpr>, data_type: DataType) -> Self {
pub fn try_new(expr: Arc<dyn PhysicalExpr>, data_type: DataType) -> DFResult<Self> {
// The `data_type` is the SUM result type passed from Spark side
let (precision, scale) = match data_type {
DataType::Decimal128(p, s) => (p, s),
_ => unreachable!(),
_ => {
return Err(DataFusionError::Internal(
"Invalid data type for SumDecimal".into(),
))
}
};
Self {
name: name.into(),
Ok(Self {
signature: Signature::user_defined(Immutable),
expr,
result_type: data_type,
precision,
scale,
nullable: true,
}
})
}
}

Expand All @@ -84,14 +84,14 @@ impl AggregateUDFImpl for SumDecimal {

fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<Field>> {
let fields = vec![
Field::new(&self.name, self.result_type.clone(), self.nullable),
Field::new(self.name(), self.result_type.clone(), self.is_nullable()),
Field::new("is_empty", DataType::Boolean, false),
];
Ok(fields)
}

fn name(&self) -> &str {
&self.name
"sum"
}

fn signature(&self) -> &Signature {
Expand Down Expand Up @@ -127,19 +127,22 @@ impl AggregateUDFImpl for SumDecimal {
/// Reports how this aggregate behaves when evaluated in reverse order.
///
/// Always returns `ReversedUDAF::Identical`.
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Identical
}

/// Reports whether the aggregate result may be null.
///
/// Unconditionally returns `true`.
fn is_nullable(&self) -> bool {
// SumDecimal is always nullable because overflows can cause null values
true
}
}

impl PartialEq<dyn Any> for SumDecimal {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.precision == x.precision
&& self.scale == x.scale
&& self.nullable == x.nullable
&& self.result_type == x.result_type
&& self.expr.eq(&x.expr)
// note that we do not compare result_type because this
// is guaranteed to match if the precision and scale
// match
self.precision == x.precision && self.scale == x.scale && self.expr.eq(&x.expr)
})
.unwrap_or(false)
}
Expand Down Expand Up @@ -170,7 +173,7 @@ impl SumDecimalAccumulator {
let v = unsafe { values.value_unchecked(idx) };
let (new_sum, is_overflow) = self.sum.overflowing_add(v);

if is_overflow || validate_decimal_precision(new_sum, self.precision).is_err() {
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
// Overflow: set buffer accumulator to null
self.is_not_null = false;
return;
Expand Down Expand Up @@ -312,7 +315,7 @@ impl SumDecimalGroupsAccumulator {
self.is_empty.set_bit(group_index, false);
let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);

if is_overflow || validate_decimal_precision(new_sum, self.precision).is_err() {
if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
// Overflow: set buffer accumulator to null
self.is_not_null.set_bit(group_index, false);
return;
Expand Down Expand Up @@ -478,3 +481,99 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
+ self.is_not_null.capacity() / 8
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::*;
    use arrow_array::builder::{Decimal128Builder, StringBuilder};
    use arrow_array::RecordBatch;
    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use datafusion::physical_plan::memory::MemoryExec;
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion_common::Result;
    use datafusion_execution::TaskContext;
    use datafusion_expr::AggregateUDF;
    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use futures::StreamExt;

    /// `SumDecimal::try_new` must return an error for a non-decimal result type.
    #[test]
    fn invalid_data_type() {
        let literal = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
        let outcome = SumDecimal::try_new(literal, DataType::Int32);
        assert!(outcome.is_err());
    }

    /// Runs a partial grouped SUM over ten copies of a generated batch and
    /// drains the output stream, failing if any output batch is an error.
    #[tokio::test]
    async fn sum_no_overflow() -> Result<()> {
        // Ten identical batches placed in a single partition.
        let batches = vec![create_record_batch(8192); 10];
        let partitions = &[batches];
        let schema = partitions[0][0].schema();

        let group_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
        let value_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

        let scan: Arc<dyn ExecutionPlan> =
            Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), None).unwrap());

        let sum_impl = SumDecimal::try_new(Arc::clone(&value_col), DataType::Decimal128(8, 2))?;
        let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(sum_impl));

        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![value_col])
            .schema(Arc::clone(&schema))
            .alias("sum")
            .with_ignore_nulls(false)
            .with_distinct(false)
            .build()?;

        let group_by = PhysicalGroupBy::new_single(vec![(group_col, "c0".to_string())]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![aggr_expr],
            vec![None], // no filter expressions
            scan,
            Arc::clone(&schema),
        )?);

        let mut stream = aggregate
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();
        while let Some(item) = stream.next().await {
            let _batch = item?;
        }

        Ok(())
    }

    /// Builds a two-column batch of `num_rows` rows: a Utf8 column `c0` and a
    /// Decimal128(38, 10) column `c1` whose raw value is the row index.
    fn create_record_batch(num_rows: usize) -> RecordBatch {
        let mut decimals = Decimal128Builder::with_capacity(num_rows);
        let mut strings = StringBuilder::with_capacity(num_rows, num_rows * 32);
        for row in 0..num_rows {
            decimals.append_value(row as i128);
            strings.append_value(format!("this is string #{}", row % 1024));
        }
        let decimal_array: ArrayRef = Arc::new(decimals.finish());
        let string_array: ArrayRef = Arc::new(strings.finish());

        // Column order: string column first, decimal column second.
        let fields = vec![
            Field::new("c0", DataType::Utf8, false),
            Field::new("c1", DataType::Decimal128(38, 10), false),
        ];
        let columns: Vec<ArrayRef> = vec![string_array, decimal_array];

        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).unwrap()
    }
}
Loading

0 comments on commit 50517f6

Please sign in to comment.