diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs
index e6b3e3155..14425f76c 100644
--- a/native/core/benches/aggregate.rs
+++ b/native/core/benches/aggregate.rs
@@ -67,7 +67,6 @@ fn criterion_benchmark(c: &mut Criterion) {
     group.bench_function("avg_decimal_comet", |b| {
         let comet_avg_decimal = Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new(
             Arc::clone(&c1),
-            "avg",
             DataType::Decimal128(38, 10),
             DataType::Decimal128(38, 10),
         )));
@@ -96,11 +95,9 @@ fn criterion_benchmark(c: &mut Criterion) {
     });

     group.bench_function("sum_decimal_comet", |b| {
-        let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new(
-            "sum",
-            Arc::clone(&c1),
-            DataType::Decimal128(38, 10),
-        )));
+        let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
+            SumDecimal::try_new(Arc::clone(&c1), DataType::Decimal128(38, 10)).unwrap(),
+        ));
         b.to_async(&rt).iter(|| {
             black_box(agg_test(
                 partitions,
diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs b/native/core/src/execution/datafusion/expressions/avg_decimal.rs
index 0462f2d3d..a265fdc29 100644
--- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs
+++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs
@@ -28,10 +28,9 @@ use datafusion_common::{not_impl_err, Result, ScalarValue};
 use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr};
 use std::{any::Any, sync::Arc};

+use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
 use arrow_array::ArrowNativeTypeOp;
-use arrow_data::decimal::{
-    validate_decimal_precision, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
-};
+use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
 use datafusion::logical_expr::Volatility::Immutable;
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
@@ -43,7 +42,6 @@ use DataType::*;
 /// AVG aggregate expression
 #[derive(Debug, Clone)]
 pub struct AvgDecimal {
-    name: String,
     signature: Signature,
     expr: Arc<dyn PhysicalExpr>,
     sum_data_type: DataType,
@@ -52,14 +50,8 @@ pub struct AvgDecimal {

 impl AvgDecimal {
     /// Create a new AVG aggregate function
-    pub fn new(
-        expr: Arc<dyn PhysicalExpr>,
-        name: impl Into<String>,
-        result_type: DataType,
-        sum_type: DataType,
-    ) -> Self {
+    pub fn new(expr: Arc<dyn PhysicalExpr>, result_type: DataType, sum_type: DataType) -> Self {
         Self {
-            name: name.into(),
             signature: Signature::user_defined(Immutable),
             expr,
             result_data_type: result_type,
@@ -95,12 +87,12 @@ impl AggregateUDFImpl for AvgDecimal {
     fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
         Ok(vec![
             Field::new(
-                format_state_name(&self.name, "sum"),
+                format_state_name(self.name(), "sum"),
                 self.sum_data_type.clone(),
                 true,
             ),
             Field::new(
-                format_state_name(&self.name, "count"),
+                format_state_name(self.name(), "count"),
                 DataType::Int64,
                 true,
             ),
         ])
@@ -108,7 +100,7 @@
     }

     fn name(&self) -> &str {
-        &self.name
+        "avg"
     }

     fn reverse_expr(&self) -> ReversedUDAF {
@@ -169,8 +161,7 @@ impl PartialEq<dyn Any> for AvgDecimal {
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
             .map(|x| {
-                self.name == x.name
-                    && self.sum_data_type == x.sum_data_type
+                self.sum_data_type == x.sum_data_type
                     && self.result_data_type == x.result_data_type
                     && self.expr.eq(&x.expr)
             })
@@ -212,7 +203,7 @@ impl AvgDecimalAccumulator {
             None => (v, false),
         };

-        if is_overflow || validate_decimal_precision(new_sum, self.sum_precision).is_err() {
+        if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
             // Overflow: set buffer accumulator to null
             self.is_not_null = false;
             return;
@@ -380,7 +371,7 @@ impl AvgDecimalGroupsAccumulator {
         let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value);
         self.counts[group_index] += 1;

-        if is_overflow || validate_decimal_precision(new_sum, self.sum_precision).is_err() {
+        if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) {
             // Overflow: set buffer accumulator to null
             self.is_not_null.set_bit(group_index, false);
             return;
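Note: the hunks above (and the matching `sum_decimal.rs` hunks below) swap arrow-rs's `Result`-returning `validate_decimal_precision` for the `bool`-returning helper added in `checkoverflow.rs`, so the per-row hot path never pays for error formatting. A minimal, self-contained sketch of the update pattern the accumulators share; `SumState` and `max_for_precision` are illustrative stand-ins, not Comet's types:

```rust
// Stand-in for arrow-rs's MAX_DECIMAL_FOR_EACH_PRECISION lookup table:
// the largest unscaled value representable at a given precision.
fn max_for_precision(p: u8) -> i128 {
    10i128.pow(p as u32) - 1
}

struct SumState {
    sum: i128,
    precision: u8,
    is_not_null: bool,
}

impl SumState {
    fn update(&mut self, v: i128) {
        // A poisoned (null) accumulator stays null for the rest of the input.
        if !self.is_not_null {
            return;
        }
        let (new_sum, overflowed) = self.sum.overflowing_add(v);
        let max = max_for_precision(self.precision);
        // Native i128 overflow or exceeding the declared precision both
        // null out the result, mirroring the diff's
        // "Overflow: set buffer accumulator to null" branches.
        if overflowed || new_sum > max || new_sum < -max {
            self.is_not_null = false;
            return;
        }
        self.sum = new_sum;
    }
}

fn main() {
    let mut s = SumState { sum: 0, precision: 8, is_not_null: true };
    s.update(99_999_999); // exactly at the Decimal128(8, _) bound: still valid
    assert!(s.is_not_null);
    s.update(1); // pushes past 10^8 - 1: the sum becomes NULL
    assert!(!s.is_not_null);
}
```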
diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs b/native/core/src/execution/datafusion/expressions/checkoverflow.rs
index e922171bd..ed03ab667 100644
--- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs
+++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs
@@ -27,7 +27,8 @@ use arrow::{
     datatypes::{Decimal128Type, DecimalType},
     record_batch::RecordBatch,
 };
-use arrow_schema::{DataType, Schema};
+use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
+use arrow_schema::{DataType, Schema, DECIMAL128_MAX_PRECISION};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
 use datafusion_common::{DataFusionError, ScalarValue};
@@ -171,3 +172,15 @@ impl PhysicalExpr for CheckOverflow {
         self.hash(&mut s);
     }
 }
+
+/// Adapted from arrow-rs `validate_decimal_precision` but returns bool
+/// instead of Err, to avoid the cost of formatting the error strings, and is
+/// optimized to remove a memcpy that exists in the original function.
+/// We can remove this code once we upgrade to a version of arrow-rs that
+/// includes https://github.com/apache/arrow-rs/pull/6419
+#[inline]
+pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
+    precision <= DECIMAL128_MAX_PRECISION
+        && value >= MIN_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
+        && value <= MAX_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
+}
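Note: returning `bool` lets callers branch without ever constructing an error, which matters when the check runs once per input row. A rough, self-contained sketch of the helper's contract; it computes the 10^p − 1 bound directly, whereas the PR indexes arrow-rs's precomputed `MIN`/`MAX_DECIMAL_FOR_EACH_PRECISION` tables (and, like them, assumes callers pass `1 <= precision <= 38`):

```rust
// Local stand-in for arrow-schema's DECIMAL128_MAX_PRECISION constant.
const DECIMAL128_MAX_PRECISION: u8 = 38;

// Same contract as the new helper, with the bound computed as 10^p - 1
// instead of read from a lookup table. Unlike the PR's version, this
// sketch also rejects precision == 0 explicitly.
fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
    if precision == 0 || precision > DECIMAL128_MAX_PRECISION {
        return false;
    }
    let max = 10i128.pow(precision as u32) - 1;
    (-max..=max).contains(&value)
}

fn main() {
    assert!(is_valid_decimal_precision(99, 2)); // fits in Decimal(2, _)
    assert!(!is_valid_decimal_precision(100, 2)); // needs 3 digits
    assert!(is_valid_decimal_precision(-99, 2)); // bounds are symmetric
    assert!(!is_valid_decimal_precision(-100, 2));
}
```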
diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs b/native/core/src/execution/datafusion/expressions/sum_decimal.rs
index e957bd25e..a3ce96b67 100644
--- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs
+++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.

+use crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
 use crate::unlikely;
 use arrow::{
     array::BooleanBufferBuilder,
@@ -23,11 +24,10 @@ use arrow::{
 use arrow_array::{
     cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
 };
-use arrow_data::decimal::validate_decimal_precision;
 use arrow_schema::{DataType, Field};
 use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator};
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
-use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue};
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::Volatility::Immutable;
 use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature};
@@ -36,37 +36,37 @@ use std::{any::Any, ops::BitAnd, sync::Arc};

 #[derive(Debug)]
 pub struct SumDecimal {
-    name: String,
+    /// Aggregate function signature
     signature: Signature,
+    /// The expression that provides the input decimal values to be summed
     expr: Arc<dyn PhysicalExpr>,
-
-    /// The data type of the SUM result
+    /// The data type of the SUM result. This will always be a decimal type
+    /// with the same precision and scale as specified in this struct
     result_type: DataType,
-
-    /// Decimal precision and scale
+    /// Decimal precision
     precision: u8,
+    /// Decimal scale
     scale: i8,
-
-    /// Whether the result is nullable
-    nullable: bool,
 }

 impl SumDecimal {
-    pub fn new(name: impl Into<String>, expr: Arc<dyn PhysicalExpr>, data_type: DataType) -> Self {
+    pub fn try_new(expr: Arc<dyn PhysicalExpr>, data_type: DataType) -> DFResult<Self> {
         // The `data_type` is the SUM result type passed from Spark side
         let (precision, scale) = match data_type {
             DataType::Decimal128(p, s) => (p, s),
-            _ => unreachable!(),
+            _ => {
+                return Err(DataFusionError::Internal(
+                    "Invalid data type for SumDecimal".into(),
+                ))
+            }
         };
-        Self {
-            name: name.into(),
+        Ok(Self {
             signature: Signature::user_defined(Immutable),
             expr,
             result_type: data_type,
             precision,
             scale,
-            nullable: true,
-        }
+        })
     }
 }
@@ -84,14 +84,14 @@ impl AggregateUDFImpl for SumDecimal {
     fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<Field>> {
         let fields = vec![
-            Field::new(&self.name, self.result_type.clone(), self.nullable),
+            Field::new(self.name(), self.result_type.clone(), self.is_nullable()),
             Field::new("is_empty", DataType::Boolean, false),
         ];
         Ok(fields)
     }

     fn name(&self) -> &str {
-        &self.name
+        "sum"
     }

     fn signature(&self) -> &Signature {
@@ -127,6 +127,11 @@ impl AggregateUDFImpl for SumDecimal {
     fn reverse_expr(&self) -> ReversedUDAF {
         ReversedUDAF::Identical
     }
+
+    fn is_nullable(&self) -> bool {
+        // SumDecimal is always nullable because overflows can cause null values
+        true
+    }
 }

 impl PartialEq<dyn Any> for SumDecimal {
@@ -134,12 +139,10 @@
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
             .map(|x| {
-                self.name == x.name
-                    && self.precision == x.precision
-                    && self.scale == x.scale
-                    && self.nullable == x.nullable
-                    && self.result_type == x.result_type
-                    && self.expr.eq(&x.expr)
+                // note that we do not compare result_type because this
+                // is guaranteed to match if the precision and scale
+                // match
+                self.precision == x.precision && self.scale == x.scale && self.expr.eq(&x.expr)
             })
             .unwrap_or(false)
     }
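Note: `try_new` replaces a constructor that could only panic (`unreachable!()`) on a non-decimal result type with one that returns a planning error, which the planner below propagates with `?`; dropping the `nullable` field is safe because `is_nullable()` now hard-codes `true` (overflow can always null the result). A minimal sketch of the fallible-constructor pattern with simplified stand-in types, not Comet's actual `DataType` or error enum:

```rust
// Simplified stand-ins for the real arrow/datafusion types.
#[derive(Debug)]
enum DataType {
    Decimal128(u8, i8),
    Int32,
}

#[derive(Debug)]
struct SumDecimal {
    precision: u8,
    scale: i8,
}

impl SumDecimal {
    // Reject a bad result type at construction time with a recoverable
    // error, rather than panicking deep inside query execution.
    fn try_new(result_type: DataType) -> Result<Self, String> {
        match result_type {
            DataType::Decimal128(p, s) => Ok(Self { precision: p, scale: s }),
            other => Err(format!("Invalid data type for SumDecimal: {other:?}")),
        }
    }
}

fn main() {
    let sum = SumDecimal::try_new(DataType::Decimal128(8, 2)).unwrap();
    println!("precision={}, scale={}", sum.precision, sum.scale);
    assert!(SumDecimal::try_new(DataType::Int32).is_err());
}
```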
@@ -170,7 +173,7 @@ impl SumDecimalAccumulator {
         let v = unsafe { values.value_unchecked(idx) };

         let (new_sum, is_overflow) = self.sum.overflowing_add(v);
-        if is_overflow || validate_decimal_precision(new_sum, self.precision).is_err() {
+        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
             // Overflow: set buffer accumulator to null
             self.is_not_null = false;
             return;
@@ -312,7 +315,7 @@ impl SumDecimalGroupsAccumulator {
         self.is_empty.set_bit(group_index, false);
         let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);

-        if is_overflow || validate_decimal_precision(new_sum, self.precision).is_err() {
+        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
             // Overflow: set buffer accumulator to null
             self.is_not_null.set_bit(group_index, false);
             return;
@@ -478,3 +481,99 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
             + self.is_not_null.capacity() / 8
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::*;
+    use arrow_array::builder::{Decimal128Builder, StringBuilder};
+    use arrow_array::RecordBatch;
+    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
+    use datafusion::physical_plan::memory::MemoryExec;
+    use datafusion::physical_plan::ExecutionPlan;
+    use datafusion_common::Result;
+    use datafusion_execution::TaskContext;
+    use datafusion_expr::AggregateUDF;
+    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+    use datafusion_physical_expr::expressions::{Column, Literal};
+    use futures::StreamExt;
+
+    #[test]
+    fn invalid_data_type() {
+        let expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
+        assert!(SumDecimal::try_new(expr, DataType::Int32).is_err());
+    }
+
+    #[tokio::test]
+    async fn sum_no_overflow() -> Result<()> {
+        let num_rows = 8192;
+        let batch = create_record_batch(num_rows);
+        let mut batches = Vec::new();
+        for _ in 0..10 {
+            batches.push(batch.clone());
+        }
+        let partitions = &[batches];
+        let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
+        let c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));
+
+        let data_type = DataType::Decimal128(8, 2);
+        let schema = Arc::clone(&partitions[0][0].schema());
+        let scan: Arc<dyn ExecutionPlan> =
+            Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), None).unwrap());
+
+        let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
+            Arc::clone(&c1),
+            data_type.clone(),
+        )?));
+
+        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
+            .schema(Arc::clone(&schema))
+            .alias("sum")
+            .with_ignore_nulls(false)
+            .with_distinct(false)
+            .build()?;
+
+        let aggregate = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]),
+            vec![aggr_expr],
+            vec![None], // no filter expressions
+            scan,
+            Arc::clone(&schema),
+        )?);
+
+        let mut stream = aggregate
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap();
+        while let Some(batch) = stream.next().await {
+            let _batch = batch?;
+        }
+
+        Ok(())
+    }
+
+    fn create_record_batch(num_rows: usize) -> RecordBatch {
+        let mut decimal_builder = Decimal128Builder::with_capacity(num_rows);
+        let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
+        for i in 0..num_rows {
+            decimal_builder.append_value(i as i128);
+            string_builder.append_value(format!("this is string #{}", i % 1024));
+        }
+        let decimal_array = Arc::new(decimal_builder.finish());
+        let string_array = Arc::new(string_builder.finish());
+
+        let mut fields = vec![];
+        let mut columns: Vec<ArrayRef> = vec![];
+
+        // string column
+        fields.push(Field::new("c0", DataType::Utf8, false));
+        columns.push(string_array);
+
+        // decimal column
+        fields.push(Field::new("c1", DataType::Decimal128(38, 10), false));
+        columns.push(decimal_array);
+
+        let schema = Schema::new(fields);
+        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
+    }
+}
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 663db0d1b..9000db61e 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -1365,56 +1365,42 @@ impl PhysicalPlanner {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
-                match datatype {
+                let builder = match datatype {
                     DataType::Decimal128(_, _) => {
-                        let func = AggregateUDF::new_from_impl(SumDecimal::new(
-                            "sum",
+                        let func = AggregateUDF::new_from_impl(SumDecimal::try_new(
                             Arc::clone(&child),
                             datatype,
-                        ));
+                        )?);
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
-                            .schema(schema)
-                            .alias("sum")
-                            .with_ignore_nulls(false)
-                            .with_distinct(false)
-                            .build()
-                            .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
                     }
                     _ => {
                         // cast to the result data type of SUM if necessary, we should not expect
                         // a cast failure since it should have already been checked at Spark side
                         let child = Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
                         AggregateExprBuilder::new(sum_udaf(), vec![child])
-                            .schema(schema)
-                            .alias("sum")
-                            .with_ignore_nulls(false)
-                            .with_distinct(false)
-                            .build()
-                            .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
                     }
-                }
+                };
+                builder
+                    .schema(schema)
+                    .alias("sum")
+                    .with_ignore_nulls(false)
+                    .with_distinct(false)
+                    .build()
+                    .map_err(|e| e.into())
             }
             AggExprStruct::Avg(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
-                match datatype {
+                let builder = match datatype {
                     DataType::Decimal128(_, _) => {
                         let func = AggregateUDF::new_from_impl(AvgDecimal::new(
                             Arc::clone(&child),
-                            "avg",
                             datatype,
                             input_datatype,
                         ));
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
-                            .schema(schema)
-                            .alias("avg")
-                            .with_ignore_nulls(false)
-                            .with_distinct(false)
-                            .build()
-                            .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
                     }
                     _ => {
                         // cast to the result data type of AVG if the result data type is different
@@ -1428,14 +1414,15 @@
                             datatype,
                         ));
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
-                            .schema(schema)
-                            .alias("avg")
-                            .with_ignore_nulls(false)
-                            .with_distinct(false)
-                            .build()
-                            .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
                     }
-                }
+                };
+                builder
+                    .schema(schema)
+                    .alias("avg")
+                    .with_ignore_nulls(false)
+                    .with_distinct(false)
+                    .build()
+                    .map_err(|e| e.into())
             }
             AggExprStruct::First(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
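Note: both planner hunks apply the same mechanical refactor: the match now evaluates to an `AggregateExprBuilder`, the shared `.schema(..)/.alias(..)/.with_ignore_nulls(..)/.with_distinct(..)/.build()` chain is written once, and the manual `ExecutionError::DataFusionError(e.to_string())` wrapping becomes `e.into()` (relying on a `From` conversion into the planner's error type, as the diff implies). A sketch of the refactored shape with a toy stand-in builder, not DataFusion's API:

```rust
// Toy builder standing in for AggregateExprBuilder.
struct Builder {
    parts: Vec<String>,
}

impl Builder {
    fn new(udf: &str) -> Self {
        Builder { parts: vec![udf.to_string()] }
    }
    fn alias(mut self, name: &str) -> Self {
        self.parts.push(format!("alias={name}"));
        self
    }
    fn build(self) -> Result<String, String> {
        Ok(self.parts.join(", "))
    }
}

fn plan_sum(input_is_decimal: bool) -> Result<String, String> {
    // Before: each arm repeated the whole .alias(..)...build() chain.
    // After: each arm only picks the UDF; the chain runs once.
    let builder = if input_is_decimal {
        Builder::new("SumDecimal (Comet)")
    } else {
        Builder::new("sum_udaf (DataFusion)")
    };
    builder.alias("sum").build()
}

fn main() {
    println!("{}", plan_sum(true).unwrap());
    println!("{}", plan_sum(false).unwrap());
}
```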