diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs index da3c54c21..5b31748d3 100644 --- a/native/core/benches/aggregate.rs +++ b/native/core/benches/aggregate.rs @@ -44,64 +44,60 @@ fn criterion_benchmark(c: &mut Criterion) { batches.push(batch.clone()); } let partitions = &[batches]; - let scan: Arc = - Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); - let schema = scan.schema().clone(); - let c0: Arc = Arc::new(Column::new("c0", 0)); let c1: Arc = Arc::new(Column::new("c1", 1)); let rt = Runtime::new().unwrap(); - let datafusion_sum_decimal = sum_udaf(); group.bench_function("aggregate - sum decimal (DataFusion)", |b| { - b.to_async(&rt).iter(|| async { - let scan: Arc = - Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); - let aggregate = create_aggregate( - scan, + let datafusion_sum_decimal = sum_udaf(); + b.to_async(&rt).iter(|| { + agg_test( + partitions, c0.clone(), c1.clone(), - &schema, datafusion_sum_decimal.clone(), - ); - let mut x = aggregate - .execute(0, Arc::new(TaskContext::default())) - .unwrap(); - while let Some(batch) = x.next().await { - let _batch = batch.unwrap(); - } + ) }) }); - let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( - "sum", - Arc::clone(&c1), - DataType::Decimal128(7, 2), - ))); group.bench_function("aggregate - sum decimal (Comet)", |b| { - b.to_async(&rt).iter(|| async { - let scan: Arc = - Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap()); - let aggregate = create_aggregate( - scan, + let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new( + "sum", + Arc::clone(&c1), + DataType::Decimal128(7, 2), + ))); + b.to_async(&rt).iter(|| { + agg_test( + partitions, c0.clone(), c1.clone(), - &schema, comet_sum_decimal.clone(), - ); - let mut x = aggregate - .execute(0, Arc::new(TaskContext::default())) - .unwrap(); - while let Some(batch) = x.next().await { - let _batch = batch.unwrap(); - } + ) }) }); group.finish(); } +async fn agg_test( + partitions: &[Vec], + c0: Arc, + c1: Arc, + aggregate_udf: Arc, +) { + let schema = &partitions[0][0].schema(); + let scan: Arc = + Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), None).unwrap()); + let aggregate = create_aggregate(scan, c0.clone(), c1.clone(), &schema, aggregate_udf); + let mut stream = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = stream.next().await { + let _batch = batch.unwrap(); + } +} + fn create_aggregate( scan: Arc, c0: Arc, @@ -146,7 +142,7 @@ fn create_record_batch(num_rows: usize) -> RecordBatch { // string column fields.push(Field::new(format!("c0"), DataType::Utf8, false)); - columns.push(Arc::clone(&string_array)); + columns.push(string_array); // decimal column fields.push(Field::new( @@ -154,7 +150,7 @@ fn create_record_batch(num_rows: usize) -> RecordBatch { DataType::Decimal128(38, 10), false, )); - columns.push(Arc::clone(&decimal_array)); + columns.push(decimal_array); let schema = Schema::new(fields); RecordBatch::try_new(Arc::new(schema), columns).unwrap()