Skip to content

Commit

Permalink
draft: ported join limited output
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Jan 8, 2024
1 parent e0bd40b commit a261e44
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 76 deletions.
187 changes: 117 additions & 70 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ use arrow::util::bit_util;
use arrow_array::cast::downcast_array;
use arrow_schema::ArrowError;
use datafusion_common::{
exec_err, internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
Expand Down Expand Up @@ -915,12 +915,16 @@ enum HashJoinStreamState {
struct ProcessProbeBatchState {
/// Probe-side batch currently being joined; it is passed as the probe batch
/// to `build_equal_condition_join_indices` on every call that processes it.
batch: RecordBatch,
/// Resume position for matching, stored as `(probe_idx, build_idx)`.
/// `None` for a freshly fetched batch; afterwards it holds the offset
/// returned by the previous `build_equal_condition_join_indices` call so
/// that the next call continues from there instead of starting over.
/// NOTE(review): the exact meaning of `build_idx` is defined by the hash
/// map's `get_n_matched_indices` — confirm against that implementation.
offset: Option<(usize, u64)>,
/// Probe-side row index of the last matched pair produced by an earlier
/// call for this batch; `None` until some call has produced a match.
/// `process_probe_batch` uses `idx + 1` as the start of the probe-row range
/// handed to `adjust_indices_by_join_type`.
last_matched_probe_idx: Option<usize>,
}

impl HashJoinStreamState {
/// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum.
/// Returns an error if state is not ProcessProbeBatchState.
fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> {
fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> {
match self {
HashJoinStreamState::ProcessProbeBatch(state) => Ok(state),
_ => internal_err!("Expected hash join stream in ProcessProbeBatch state"),
Expand Down Expand Up @@ -1032,7 +1036,14 @@ pub(crate) fn build_equal_condition_join_indices<T: JoinHashMapType>(
build_side: JoinSide,
deleted_offset: Option<usize>,
fifo_hashmap: bool,
) -> Result<(UInt64Array, UInt32Array)> {
output_limit: usize,
output_offset: Option<(usize, u64)>,
) -> Result<(
UInt64Array,
UInt32Array,
Option<(usize, u64)>,
Option<usize>,
)> {
let keys_values = probe_on
.iter()
.map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))
Expand Down Expand Up @@ -1078,16 +1089,21 @@ pub(crate) fn build_equal_condition_join_indices<T: JoinHashMapType>(
// (5,1)
//
// With this approach, the lexicographic order on both the probe side and the build side is preserved.
let (mut probe_indices, mut build_indices) = if fifo_hashmap {
build_hashmap.get_matched_indices(hash_values.iter().enumerate(), deleted_offset)
let (mut probe_indices, mut build_indices, next_offset) = if fifo_hashmap {
build_hashmap.get_n_matched_indices(
hash_values.iter().enumerate(),
deleted_offset,
output_limit,
output_offset,
)
} else {
let (mut matched_probe, mut matched_build) = build_hashmap
.get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset);

matched_probe.as_slice_mut().reverse();
matched_build.as_slice_mut().reverse();

(matched_probe, matched_build)
(matched_probe, matched_build, None)
};

let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None);
Expand All @@ -1107,15 +1123,21 @@ pub(crate) fn build_equal_condition_join_indices<T: JoinHashMapType>(
(left, right)
};

let matched_indices = equal_rows_arr(
let (left_joined, right_joined) = equal_rows_arr(
&left,
&right,
&build_join_values,
&keys_values,
null_equals_null,
)?;

Ok((matched_indices.0, matched_indices.1))
let last_matched_idx = if !right_joined.is_empty() {
Some(right_joined.value(right_joined.len() - 1) as usize)
} else {
None
};

Ok((left_joined, right_joined, next_offset, last_matched_idx))
}

// version of eq_dyn supporting equality on null arrays
Expand Down Expand Up @@ -1263,6 +1285,8 @@ impl HashJoinStream {
self.state =
HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
batch,
offset: None,
last_matched_probe_idx: None,
});
}
Some(Err(err)) => return Poll::Ready(Err(err)),
Expand All @@ -1277,7 +1301,7 @@ impl HashJoinStream {
fn process_probe_batch(
&mut self,
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
let state = self.state.try_as_process_probe_batch()?;
let state = self.state.try_as_process_probe_batch_mut()?;
let build_side = self.build_side.try_as_ready_mut()?;

self.join_metrics.input_batches.add(1);
Expand All @@ -1286,59 +1310,79 @@ impl HashJoinStream {

let mut hashes_buffer = vec![];
// get the matched two indices for the on condition
let left_right_indices = build_equal_condition_join_indices(
build_side.left_data.hash_map(),
let (left_side, right_side, next_offset, last_matched_probe_idx) =
build_equal_condition_join_indices(
build_side.left_data.hash_map(),
build_side.left_data.batch(),
&state.batch,
&self.on_left,
&self.on_right,
&self.random_state,
self.null_equals_null,
&mut hashes_buffer,
self.filter.as_ref(),
JoinSide::Left,
None,
true,
self.batch_size,
state.offset,
)?;

// set the left bitmap
// and only left, full, left semi, left anti need the left bitmap
if need_produce_result_in_final(self.join_type) {
left_side.iter().flatten().for_each(|x| {
build_side.visited_left_side.set_bit(x as usize, true);
});
}

let adjust_range_start = state.last_matched_probe_idx.map_or(0, |v| v + 1);
let adjust_range_end = if next_offset.is_none()
|| next_offset.is_some_and(|(probe_idx, build_idx)| {
probe_idx + 1 >= state.batch.num_rows() && build_idx == 0
}) {
state.batch.num_rows()
} else {
last_matched_probe_idx.map_or(0, |v| v + 1)
};

// adjust the two side indices based on the join type
let (left_side, right_side) = adjust_indices_by_join_type(
left_side,
right_side,
adjust_range_start..adjust_range_end,
self.join_type,
);

let result = build_batch_from_indices(
&self.schema,
build_side.left_data.batch(),
&state.batch,
&self.on_left,
&self.on_right,
&self.random_state,
self.null_equals_null,
&mut hashes_buffer,
self.filter.as_ref(),
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
None,
true,
);

let result = match left_right_indices {
Ok((left_side, right_side)) => {
// set the left bitmap
// and only left, full, left semi, left anti need the left bitmap
if need_produce_result_in_final(self.join_type) {
left_side.iter().flatten().for_each(|x| {
build_side.visited_left_side.set_bit(x as usize, true);
});
}

// adjust the two side indices base on the join type
let (left_side, right_side) = adjust_indices_by_join_type(
left_side,
right_side,
0..state.batch.num_rows(),
self.join_type,
);
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(state.batch.num_rows());
timer.done();

let result = build_batch_from_indices(
&self.schema,
build_side.left_data.batch(),
&state.batch,
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
);
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(state.batch.num_rows());
result
match next_offset {
Some((probe_idx, build_idx)) => {
if probe_idx + 1 >= state.batch.num_rows() && build_idx == 0 {
self.state = HashJoinStreamState::FetchProbeBatch;
} else {
state.offset = Some((probe_idx, build_idx));
if last_matched_probe_idx.is_some() {
state.last_matched_probe_idx = last_matched_probe_idx;
};
}
}
Err(err) => {
exec_err!("Fail to build join indices in HashJoinExec, error:{err}")
None => {
self.state = HashJoinStreamState::FetchProbeBatch;
}
};
timer.done();

self.state = HashJoinStreamState::FetchProbeBatch;

Ok(StatefulStreamResult::Ready(Some(result?)))
}
Expand Down Expand Up @@ -1414,7 +1458,8 @@ mod tests {
use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue,
assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err,
ScalarValue,
};
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
Expand Down Expand Up @@ -2914,7 +2959,7 @@ mod tests {

let join_hash_map = JoinHashMap::new(hashmap_left, next);

let (l, r) = build_equal_condition_join_indices(
let (l, r, _, _) = build_equal_condition_join_indices(
&join_hash_map,
&left,
&right,
Expand All @@ -2927,6 +2972,8 @@ mod tests {
JoinSide::Left,
None,
false,
8192,
None,
)?;

let mut left_ids = UInt64Builder::with_capacity(0);
Expand Down Expand Up @@ -3314,26 +3361,26 @@ mod tests {
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b2 | c2 |",
"+----+----+----+----+----+----+",
"| 4 | 1 | 0 | 10 | 1 | 0 |",
"| 3 | 1 | 0 | 10 | 1 | 0 |",
"| 2 | 1 | 0 | 10 | 1 | 0 |",
"| 1 | 1 | 0 | 10 | 1 | 0 |",
"| 4 | 1 | 0 | 20 | 1 | 0 |",
"| 3 | 1 | 0 | 20 | 1 | 0 |",
"| 2 | 1 | 0 | 20 | 1 | 0 |",
"| 2 | 1 | 0 | 10 | 1 | 0 |",
"| 3 | 1 | 0 | 10 | 1 | 0 |",
"| 4 | 1 | 0 | 10 | 1 | 0 |",
"| 1 | 1 | 0 | 20 | 1 | 0 |",
"| 4 | 1 | 0 | 30 | 1 | 0 |",
"| 3 | 1 | 0 | 30 | 1 | 0 |",
"| 2 | 1 | 0 | 30 | 1 | 0 |",
"| 2 | 1 | 0 | 20 | 1 | 0 |",
"| 3 | 1 | 0 | 20 | 1 | 0 |",
"| 4 | 1 | 0 | 20 | 1 | 0 |",
"| 1 | 1 | 0 | 30 | 1 | 0 |",
"| 4 | 1 | 0 | 40 | 1 | 0 |",
"| 3 | 1 | 0 | 40 | 1 | 0 |",
"| 2 | 1 | 0 | 40 | 1 | 0 |",
"| 2 | 1 | 0 | 30 | 1 | 0 |",
"| 3 | 1 | 0 | 30 | 1 | 0 |",
"| 4 | 1 | 0 | 30 | 1 | 0 |",
"| 1 | 1 | 0 | 40 | 1 | 0 |",
"| 4 | 1 | 0 | 50 | 1 | 0 |",
"| 3 | 1 | 0 | 50 | 1 | 0 |",
"| 2 | 1 | 0 | 50 | 1 | 0 |",
"| 2 | 1 | 0 | 40 | 1 | 0 |",
"| 3 | 1 | 0 | 40 | 1 | 0 |",
"| 4 | 1 | 0 | 40 | 1 | 0 |",
"| 1 | 1 | 0 | 50 | 1 | 0 |",
"| 2 | 1 | 0 | 50 | 1 | 0 |",
"| 3 | 1 | 0 | 50 | 1 | 0 |",
"| 4 | 1 | 0 | 50 | 1 | 0 |",
"+----+----+----+----+----+----+",
];
let left_batch = [
Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-plan/src/joins/symmetric_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ pub(crate) fn join_with_probe_batch(
if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
return Ok(None);
}
let (build_indices, probe_indices) = build_equal_condition_join_indices(
let (build_indices, probe_indices, _, _) = build_equal_condition_join_indices(
&build_hash_joiner.hashmap,
&build_hash_joiner.input_buffer,
probe_batch,
Expand All @@ -772,6 +772,8 @@ pub(crate) fn join_with_probe_batch(
build_hash_joiner.build_side,
Some(build_hash_joiner.deleted_offset),
false,
usize::MAX,
None,
)?;

if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
Expand Down
Loading

0 comments on commit a261e44

Please sign in to comment.