Skip to content

Commit

Permalink
[MINOR] Reduce complexity on SHJ (#7607)
Browse files Browse the repository at this point in the history
* Before fix

* Clippy

* Minor changes

* Simplifications

---------

Co-authored-by: Mustafa Akur <mustafa.akur@synnada.ai>
  • Loading branch information
metesynnada and mustafasrepo authored Sep 22, 2023
1 parent d193508 commit ee9078d
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 128 deletions.
105 changes: 2 additions & 103 deletions datafusion/physical-plan/src/joins/hash_join_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@
//! related functionality, used both in join calculations and optimization rules.

use std::collections::{HashMap, VecDeque};
use std::fmt::{Debug, Formatter};
use std::fmt::Debug;
use std::ops::IndexMut;
use std::sync::Arc;
use std::{fmt, usize};

use crate::joins::utils::{JoinFilter, JoinSide};
use crate::ExecutionPlan;

use arrow::compute::concat_batches;
use arrow::datatypes::{ArrowNativeType, SchemaRef};
Expand All @@ -34,13 +33,12 @@ use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval, IntervalBound};
use datafusion_physical_expr::intervals::{Interval, IntervalBound};
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};

use hashbrown::raw::RawTable;
use hashbrown::HashSet;
use parking_lot::Mutex;

// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value.
// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side,
Expand Down Expand Up @@ -446,105 +444,6 @@ fn convert_filter_columns(
})
}

/// State used for interval calculations, filled in by
/// `build_filter_expression_graph` the first time it is invoked and handed
/// back (cloned) on later invocations.
#[derive(Default)]
pub struct IntervalCalculatorInnerState {
/// Expression graph for interval calculations
graph: Option<ExprIntervalGraph>,
/// Sorted filter expressions for the left (index 0) and right (index 1)
/// join sides; an entry is `None` when no sorted filter expression could
/// be derived for that side.
sorted_exprs: Vec<Option<SortedFilterExpr>>,
/// Set to `true` once the fields above have been initialized, so the
/// initialization is not repeated.
calculated: bool,
}

/// Debug output shows only the sorted filter expressions, wrapped as
/// `Exprs(...)`; the graph and the `calculated` flag are omitted.
impl Debug for IntervalCalculatorInnerState {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let exprs = &self.sorted_exprs;
        f.write_fmt(format_args!("Exprs({:?})", exprs))
    }
}

/// Returns the sorted filter expressions for the left and right join sides
/// together with the expression interval graph, computing them on the first
/// invocation and serving clones of the cached values afterwards.
///
/// # Arguments
///
/// * `interval_state` - Shared, mutex-protected cache of the computed state.
/// * `left` - The left child plan.
/// * `right` - The right child plan.
/// * `filter` - The join filter the sorted filter expressions are built from.
///
/// # Returns
///
/// A tuple `(left_sorted_expr, right_sorted_expr, graph)`. A sorted expression
/// is `None` when the corresponding child has no (non-empty) output ordering or
/// when no filter input order can be built from its leading sort expression.
/// The graph is `Some` only if both sorted expressions are available.
///
/// # Errors
///
/// Propagates errors from `build_filter_input_order` and
/// `ExprIntervalGraph::try_new`. On error the shared state is left untouched,
/// so a later invocation starts from a clean slate.
pub fn build_filter_expression_graph(
    interval_state: &Arc<Mutex<IntervalCalculatorInnerState>>,
    left: &Arc<dyn ExecutionPlan>,
    right: &Arc<dyn ExecutionPlan>,
    filter: &JoinFilter,
) -> Result<(
    Option<SortedFilterExpr>,
    Option<SortedFilterExpr>,
    Option<ExprIntervalGraph>,
)> {
    // Lock the mutex of the interval state:
    let mut filter_state = interval_state.lock();
    // If this is the first partition to be invoked, then we need to initialize our state
    // (the expression graph for pruning, sorted filter expressions etc.)
    if !filter_state.calculated {
        // Interval calculations require each column to exhibit monotonicity
        // independently. However, a `PhysicalSortExpr` object defines a
        // lexicographical ordering, so we can only use their first elements
        // when deducing column monotonicities.
        // TODO: Extend the `PhysicalSortExpr` mechanism to express independent
        // (i.e. simultaneous) ordering properties of columns.

        // Build everything in locals first and only publish to the shared
        // state once every fallible step has succeeded. Otherwise an early
        // return via `?` would leave partially pushed entries behind and the
        // next invocation would append to them.
        let mut sorted_exprs = Vec::with_capacity(2);
        for (join_side, child) in [(JoinSide::Left, left), (JoinSide::Right, right)] {
            let sorted_expr = child
                .output_ordering()
                // An empty ordering is treated like a missing one instead of
                // panicking on an out-of-bounds index.
                .and_then(|orders| orders.first())
                .and_then(|order| {
                    build_filter_input_order(join_side, filter, &child.schema(), order)
                        .transpose()
                })
                .transpose()?;
            sorted_exprs.push(sorted_expr);
        }

        // Create the expression graph if we can create sorted filter expressions for both children:
        let graph = if sorted_exprs.iter().all(Option::is_some) {
            let mut graph = ExprIntervalGraph::try_new(filter.expression().clone())?;

            // Gather filter expressions:
            let filter_exprs = sorted_exprs
                .iter()
                .flatten()
                .map(|sorted_expr| sorted_expr.filter_expr().clone())
                .collect::<Vec<_>>();

            // Gather node indices of converted filter expressions in `SortedFilterExpr`s
            // using the filter columns vector:
            let child_node_indices = graph.gather_node_indices(&filter_exprs);

            // Update SortedFilterExpr instances with the corresponding node indices:
            for (sorted_expr, (_, index)) in sorted_exprs
                .iter_mut()
                .flatten()
                .zip(child_node_indices.iter())
            {
                sorted_expr.set_node_index(*index);
            }

            Some(graph)
        } else {
            None
        };

        // Publish the fully built state in one go:
        filter_state.sorted_exprs = sorted_exprs;
        filter_state.graph = graph;
        filter_state.calculated = true;
    }
    // Return the sorted filter expressions for both sides along with the expression graph:
    Ok((
        filter_state.sorted_exprs.first().cloned().flatten(),
        filter_state.sorted_exprs.get(1).cloned().flatten(),
        filter_state.graph.clone(),
    ))
}

/// The [SortedFilterExpr] object represents a sorted filter expression. It
/// contains the following information: The origin expression, the filter
/// expression, an interval encapsulating expression bounds, and a stable
Expand Down
46 changes: 21 additions & 25 deletions datafusion/physical-plan/src/joins/symmetric_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ use std::{any::Any, usize};
use crate::common::SharedMemoryReservation;
use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash};
use crate::joins::hash_join_utils::{
build_filter_expression_graph, calculate_filter_expr_intervals, combine_two_batches,
calculate_filter_expr_intervals, combine_two_batches,
convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
get_pruning_semi_indices, record_visited_indices, IntervalCalculatorInnerState,
PruningJoinHashMap,
get_pruning_semi_indices, record_visited_indices, PruningJoinHashMap,
};
use crate::joins::StreamJoinPartitionMode;
use crate::DisplayAs;
Expand Down Expand Up @@ -69,6 +68,7 @@ use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::intervals::ExprIntervalGraph;

use crate::joins::utils::prepare_sorted_exprs;
use ahash::RandomState;
use futures::stream::{select, BoxStream};
use futures::{Stream, StreamExt};
Expand Down Expand Up @@ -174,8 +174,6 @@ pub struct SymmetricHashJoinExec {
pub(crate) filter: Option<JoinFilter>,
/// How the join is performed
pub(crate) join_type: JoinType,
/// Expression graph and `SortedFilterExpr`s for interval calculations
filter_state: Option<Arc<Mutex<IntervalCalculatorInnerState>>>,
/// The schema once the join is applied
schema: SchemaRef,
/// Shares the `RandomState` for the hashing algorithm
Expand Down Expand Up @@ -285,20 +283,12 @@ impl SymmetricHashJoinExec {
// Initialize the random state for the join operation:
let random_state = RandomState::with_seeds(0, 0, 0, 0);

let filter_state = if filter.is_some() {
let inner_state = IntervalCalculatorInnerState::default();
Some(Arc::new(Mutex::new(inner_state)))
} else {
None
};

Ok(SymmetricHashJoinExec {
left,
right,
on,
filter,
join_type: *join_type,
filter_state,
schema: Arc::new(schema),
random_state,
metrics: ExecutionPlanMetricsSet::new(),
Expand Down Expand Up @@ -496,21 +486,27 @@ impl ExecutionPlan for SymmetricHashJoinExec {
);
}
// If `filter_state` and `filter` are both present, then calculate sorted filter expressions
// for both sides, and build an expression graph if one is not already built.
let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
match (&self.filter_state, &self.filter) {
(Some(interval_state), Some(filter)) => build_filter_expression_graph(
interval_state,
// for both sides, and build an expression graph.
let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (
self.left.output_ordering(),
self.right.output_ordering(),
&self.filter,
) {
(Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
let (left, right, graph) = prepare_sorted_exprs(
filter,
&self.left,
&self.right,
filter,
)?,
// If `filter_state` or `filter` is not present, then return None for all three values:
(_, _) => (None, None, None),
};
left_sort_exprs,
right_sort_exprs,
)?;
(Some(left), Some(right), Some(graph))
}
// If the filter or either child's output ordering is not present, then return None for all three values:
_ => (None, None, None),
};

let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
let (on_left, on_right) = self.on.iter().cloned().unzip();

let left_side_joiner =
OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
Expand Down
87 changes: 87 additions & 0 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ use datafusion_physical_expr::{
PhysicalSortExpr,
};

use crate::joins::hash_join_utils::{build_filter_input_order, SortedFilterExpr};
use datafusion_physical_expr::intervals::ExprIntervalGraph;
use datafusion_physical_expr::utils::merge_vectors;
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
Expand Down Expand Up @@ -1295,6 +1297,91 @@ impl BuildProbeJoinMetrics {
}
}

/// Assigns to every sorted filter expression the node index that the
/// expression interval graph reports for its filter expression.
///
/// The filter expressions are collected in slice order, handed to
/// `ExprIntervalGraph::gather_node_indices`, and the returned indices are
/// written back pairwise, in the same order, via `set_node_index`.
fn update_sorted_exprs_with_node_indices(
    graph: &mut ExprIntervalGraph,
    sorted_exprs: &mut [SortedFilterExpr],
) {
    // One filter expression per sorted expression, order preserved:
    let mut exprs_for_lookup = Vec::with_capacity(sorted_exprs.len());
    for item in sorted_exprs.iter() {
        exprs_for_lookup.push(item.filter_expr().clone());
    }

    // Ask the graph for the node index of each of those expressions:
    let node_indices = graph.gather_node_indices(&exprs_for_lookup);

    // Pair each sorted expression with its index and store it:
    sorted_exprs
        .iter_mut()
        .zip(node_indices)
        .for_each(|(item, (_, node_index))| item.set_node_index(node_index));
}

/// Builds the sorted filter expressions for both join sides and the
/// expression interval graph of the join filter.
///
/// Only the leading sort expression of each side is used to derive the
/// corresponding sorted filter expression. No sorting is performed here.
///
/// # Arguments
///
/// * `filter` - The join filter the sorted filter expressions are built from.
/// * `left` - The left execution plan (its schema is used).
/// * `right` - The right execution plan (its schema is used).
/// * `left_sort_exprs` - The ordering of the left side.
/// * `right_sort_exprs` - The ordering of the right side.
///
/// # Returns
///
/// * A tuple consisting of the sorted filter expression for the left and right
///   sides (with their graph node indices set), and the expression interval graph.
///
/// # Errors
///
/// Returns a `DataFusionError::Plan` if either ordering is empty or if no
/// filter input order can be built from a side's leading sort expression.
/// Errors from `build_filter_input_order` and `ExprIntervalGraph::try_new`
/// are propagated.
pub fn prepare_sorted_exprs(
    filter: &JoinFilter,
    left: &Arc<dyn ExecutionPlan>,
    right: &Arc<dyn ExecutionPlan>,
    left_sort_exprs: &[PhysicalSortExpr],
    right_sort_exprs: &[PhysicalSortExpr],
) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
    let err =
        || DataFusionError::Plan("Filter does not include the child order".to_owned());

    // An empty ordering cannot be covered by the filter either, so report it
    // with the same error instead of panicking on an out-of-bounds index.
    let left_order = left_sort_exprs.first().ok_or_else(err)?;
    let right_order = right_sort_exprs.first().ok_or_else(err)?;

    // Build the filter order for the left side
    let left_temp_sorted_filter_expr =
        build_filter_input_order(JoinSide::Left, filter, &left.schema(), left_order)?
            .ok_or_else(err)?;

    // Build the filter order for the right side
    let right_temp_sorted_filter_expr =
        build_filter_input_order(JoinSide::Right, filter, &right.schema(), right_order)?
            .ok_or_else(err)?;

    // Keep both expressions in a fixed-size array so they can be updated
    // together as a slice and moved back out without any index juggling.
    let mut sorted_exprs = [left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];

    // Build the expression interval graph
    let mut graph = ExprIntervalGraph::try_new(filter.expression().clone())?;

    // Update sorted expressions with node indices
    update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);

    let [left_sorted_filter_expr, right_sorted_filter_expr] = sorted_exprs;

    Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit ee9078d

Please sign in to comment.