independent hash per batch
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
jayzhan211 committed Oct 3, 2024
1 parent 80aae67 commit ed8e135
Showing 1 changed file with 33 additions and 28 deletions.
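
In short: this change removes the long-lived `hashes_buffer: Vec<Vec<u64>>` field from `GroupedHashAggregateStream` and instead has each batch compute its own hash buffer, which the helper functions return and `group_aggregate_batch` now receives by reference. If the skip-aggregation probe decides mid-batch that partial aggregation should be skipped, the probe-updating helper returns an empty Vec, and the caller breaks out of the read loop before the buffer would be used. A minimal standalone sketch of the per-batch pattern follows; it uses std's `RandomState::hash_one` over plain `i64` keys as a stand-in for DataFusion's ahash-based `create_hashes` over Arrow `ArrayRef` columns, so the names and types here are illustrative, not the crate's API:

use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

fn compute_group_by_hashes(
    group_by_values: &[Vec<i64>], // one Vec of row keys per grouping set
    random_state: &RandomState,
) -> Vec<Vec<u64>> {
    // One independent buffer per batch: allocated here, filled, and returned,
    // rather than resized and reused through a struct field across batches.
    group_by_values
        .iter()
        .map(|rows| rows.iter().map(|row| random_state.hash_one(row)).collect())
        .collect()
}

fn main() {
    let random_state = RandomState::new();
    let batch = vec![vec![1_i64, 2, 2, 3]];
    // Each call yields a buffer owned by this batch alone; nothing carries
    // over to the next call.
    let hashes = compute_group_by_hashes(&batch, &random_state);
    assert_eq!(hashes[0].len(), 4);
}

The trade-off is a fresh allocation per batch in exchange for hashes that are independent of any earlier batch's buffer.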
datafusion/physical-plan/src/aggregates/row_hash.rs: 33 additions & 28 deletions
@@ -378,9 +378,6 @@ pub(crate) struct GroupedHashAggregateStream {
     /// Indicates whether we skip the partial aggregation
     skip_partial_aggregation: bool,
 
-    /// Store hashes
-    hashes_buffer: Vec<Vec<u64>>,
-
     /// Random state for creating hashes
     random_state: RandomState,
 
@@ -528,7 +525,6 @@ impl GroupedHashAggregateStream {
             spill_state,
             group_values_soft_limit: agg.limit,
             skip_aggregation_probe,
-            hashes_buffer: Default::default(),
             random_state: Default::default(),
             skip_partial_aggregation: false,
         })
@@ -588,8 +584,7 @@ impl Stream for GroupedHashAggregateStream {
                         // Do the grouping
                         let group_by_values =
                             self.evalute_grouping_expressions(&batch)?;
-
-                        self.compute_group_by_hashes_and_update_skip_aggregation_probe(
+                        let hashes_buffer = self.compute_group_by_hashes_and_update_skip_aggregation_probe(
                             &group_by_values,
                         )?;
                         if self.skip_partial_aggregation {
@@ -598,10 +593,11 @@ impl Stream for GroupedHashAggregateStream {
                             // make sure the exec_state just set is not overwritten below
                             break 'reading_input;
                         }
-
-                        extract_ok!(
-                            self.group_aggregate_batch(batch, group_by_values)
-                        );
+                        extract_ok!(self.group_aggregate_batch(
+                            batch,
+                            group_by_values,
+                            &hashes_buffer
+                        ));
 
                         // If we can begin emitting rows, do so,
                         // otherwise keep consuming input
@@ -636,10 +632,13 @@ impl Stream for GroupedHashAggregateStream {
                         // Do the grouping
                         let group_by_values =
                             self.evalute_grouping_expressions(&batch)?;
-                        self.compute_group_by_hashes(&group_by_values)?;
-                        extract_ok!(
-                            self.group_aggregate_batch(batch, group_by_values)
-                        );
+                        let hashes_buffer =
+                            self.compute_group_by_hashes(&group_by_values)?;
+                        extract_ok!(self.group_aggregate_batch(
+                            batch,
+                            group_by_values,
+                            &hashes_buffer
+                        ));
 
                         // If we can begin emitting rows, do so,
                         // otherwise keep consuming input
@@ -678,10 +677,13 @@ impl Stream for GroupedHashAggregateStream {
                         // Do the grouping
                         let group_by_values =
                             self.evalute_grouping_expressions(&batch)?;
-                        self.compute_group_by_hashes(&group_by_values)?;
-                        extract_ok!(
-                            self.group_aggregate_batch(batch, group_by_values)
-                        );
+                        let hashes_buffer =
+                            self.compute_group_by_hashes(&group_by_values)?;
+                        extract_ok!(self.group_aggregate_batch(
+                            batch,
+                            group_by_values,
+                            &hashes_buffer
+                        ));
 
                         // If we can begin emitting rows, do so,
                         // otherwise keep consuming input
@@ -795,39 +797,41 @@ impl GroupedHashAggregateStream {
     fn compute_group_by_hashes_and_update_skip_aggregation_probe(
         &mut self,
         group_by_values: &[Vec<ArrayRef>],
-    ) -> Result<()> {
-        self.hashes_buffer.resize(group_by_values.len(), Vec::new());
+    ) -> Result<Vec<Vec<u64>>> {
+        let mut hashes_buffer: Vec<Vec<u64>> = Vec::default();
+        hashes_buffer.resize(group_by_values.len(), Vec::new());
         for (index, group_values) in group_by_values.iter().enumerate() {
             let n_rows = group_values[0].len();
-            let batch_hashes = &mut self.hashes_buffer[index];
+            let batch_hashes = &mut hashes_buffer[index];
             batch_hashes.resize(n_rows, 0);
             create_hashes(group_values, &self.random_state, batch_hashes)?;
 
             // This function should be called if skip aggregation is supported
             let probe = self.skip_aggregation_probe.as_mut().unwrap();
             self.skip_partial_aggregation = probe.update_state(batch_hashes);
             if self.skip_partial_aggregation {
-                return Ok(());
+                return Ok(vec![]);
             }
         }
 
-        Ok(())
+        Ok(hashes_buffer)
     }
 
     /// Compute hashes without updating the skip-aggregation probe
     fn compute_group_by_hashes(
         &mut self,
         group_by_values: &[Vec<ArrayRef>],
-    ) -> Result<()> {
-        self.hashes_buffer.resize(group_by_values.len(), Vec::new());
+    ) -> Result<Vec<Vec<u64>>> {
+        let mut hashes_buffer: Vec<Vec<u64>> = Vec::default();
+        hashes_buffer.resize(group_by_values.len(), Vec::new());
         for (index, group_values) in group_by_values.iter().enumerate() {
             let n_rows = group_values[0].len();
-            let batch_hashes = &mut self.hashes_buffer[index];
+            let batch_hashes = &mut hashes_buffer[index];
             batch_hashes.resize(n_rows, 0);
             create_hashes(group_values, &self.random_state, batch_hashes)?;
         }
 
-        Ok(())
+        Ok(hashes_buffer)
     }
 
     fn evalute_grouping_expressions(
@@ -846,6 +850,7 @@ impl GroupedHashAggregateStream {
         &mut self,
         batch: RecordBatch,
         group_by_values: Vec<Vec<ArrayRef>>,
+        hashes_buffer: &[Vec<u64>],
     ) -> Result<()> {
         // Evaluate the aggregation expressions.
         let input_values = if self.spill_state.is_stream_merging {
@@ -868,7 +873,7 @@ impl GroupedHashAggregateStream {
             self.group_values.intern(
                 group_values,
                 &mut self.current_group_indices,
-                &self.hashes_buffer[index],
+                &hashes_buffer[index],
             )?;
 
             let group_indices = &self.current_group_indices;
