Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Sep 18, 2024
1 parent af9fc15 commit 36260ad
Showing 1 changed file with 34 additions and 38 deletions.
72 changes: 34 additions & 38 deletions native/core/benches/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,64 +44,60 @@ fn criterion_benchmark(c: &mut Criterion) {
batches.push(batch.clone());
}
let partitions = &[batches];
let scan: Arc<dyn ExecutionPlan> =
Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap());
let schema = scan.schema().clone();

let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
let c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

let rt = Runtime::new().unwrap();

let datafusion_sum_decimal = sum_udaf();
group.bench_function("aggregate - sum decimal (DataFusion)", |b| {
b.to_async(&rt).iter(|| async {
let scan: Arc<dyn ExecutionPlan> =
Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap());
let aggregate = create_aggregate(
scan,
let datafusion_sum_decimal = sum_udaf();
b.to_async(&rt).iter(|| {
agg_test(
partitions,
c0.clone(),
c1.clone(),
&schema,
datafusion_sum_decimal.clone(),
);
let mut x = aggregate
.execute(0, Arc::new(TaskContext::default()))
.unwrap();
while let Some(batch) = x.next().await {
let _batch = batch.unwrap();
}
)
})
});

let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new(
"sum",
Arc::clone(&c1),
DataType::Decimal128(7, 2),
)));
group.bench_function("aggregate - sum decimal (Comet)", |b| {
b.to_async(&rt).iter(|| async {
let scan: Arc<dyn ExecutionPlan> =
Arc::new(MemoryExec::try_new(partitions, batch.schema(), None).unwrap());
let aggregate = create_aggregate(
scan,
let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(SumDecimal::new(
"sum",
Arc::clone(&c1),
DataType::Decimal128(7, 2),
)));
b.to_async(&rt).iter(|| {
agg_test(
partitions,
c0.clone(),
c1.clone(),
&schema,
comet_sum_decimal.clone(),
);
let mut x = aggregate
.execute(0, Arc::new(TaskContext::default()))
.unwrap();
while let Some(batch) = x.next().await {
let _batch = batch.unwrap();
}
)
})
});

group.finish();
}

/// Builds an aggregate plan over `partitions` and drains its output stream.
///
/// A fresh `MemoryExec` scan is created over `partitions` on every call, using
/// the schema of the first batch of the first partition. The scan is handed to
/// `create_aggregate` together with `c0`, `c1` and `aggregate_udf`; partition 0
/// of the resulting plan is then executed with a default `TaskContext`, and
/// every batch it yields is pulled and discarded.
///
/// # Panics
///
/// Panics if `partitions` is empty or its first partition is empty, if the
/// `MemoryExec` cannot be constructed, if `execute` returns an error, or if
/// the stream yields an error.
async fn agg_test(
    partitions: &[Vec<RecordBatch>],
    c0: Arc<dyn PhysicalExpr>,
    c1: Arc<dyn PhysicalExpr>,
    aggregate_udf: Arc<AggregateUDF>,
) {
    // Own the schema handle so it can be shared with the scan and borrowed by
    // `create_aggregate` without a double reference.
    let schema = partitions[0][0].schema();
    let scan: Arc<dyn ExecutionPlan> =
        Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), None).unwrap());
    // `c0` and `c1` are already owned here and not used again, so they are
    // moved into the plan instead of being cloned.
    let aggregate = create_aggregate(scan, c0, c1, &schema, aggregate_udf);
    let mut stream = aggregate
        .execute(0, Arc::new(TaskContext::default()))
        .unwrap();
    // Pull every batch so the whole aggregation actually runs; the results
    // themselves are not inspected.
    while let Some(batch) = stream.next().await {
        let _batch = batch.unwrap();
    }
}

fn create_aggregate(
scan: Arc<dyn ExecutionPlan>,
c0: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -146,15 +142,15 @@ fn create_record_batch(num_rows: usize) -> RecordBatch {

// string column
fields.push(Field::new(format!("c0"), DataType::Utf8, false));
columns.push(Arc::clone(&string_array));
columns.push(string_array);

// decimal column
fields.push(Field::new(
format!("c1"),
DataType::Decimal128(38, 10),
false,
));
columns.push(Arc::clone(&decimal_array));
columns.push(decimal_array);

let schema = Schema::new(fields);
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
Expand Down

0 comments on commit 36260ad

Please sign in to comment.