feat: Support HashJoin operator #194

Merged · 12 commits · Mar 20, 2024
274 changes: 271 additions & 3 deletions core/src/execution/datafusion/planner.rs
@@ -17,7 +17,7 @@

//! Converts Spark physical plan to DataFusion physical plan

use std::{str::FromStr, sync::Arc};
use std::{collections::HashMap, str::FromStr, sync::Arc};

use arrow_schema::{DataType, Field, Schema, TimeUnit};
use datafusion::{
@@ -37,13 +37,17 @@ use datafusion::{
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
joins::{utils::JoinFilter, HashJoinExec, PartitionMode},
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
ExecutionPlan, Partitioning,
},
};
use datafusion_common::ScalarValue;
use datafusion_common::{
tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion},
JoinType as DFJoinType, ScalarValue,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
@@ -76,7 +80,7 @@ use crate::{
agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
ScalarFunc,
},
spark_operator::{operator::OpStruct, Operator},
spark_operator::{operator::OpStruct, JoinType, Operator},
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
},
};
@@ -858,6 +862,133 @@ impl PhysicalPlanner {
Arc::new(CometExpandExec::new(projections, child, schema)),
))
}
OpStruct::HashJoin(join) => {
assert!(children.len() == 2);
let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;

left_scans.append(&mut right_scans);

let left_join_exprs: Vec<_> = join
.left_join_keys
.iter()
.map(|expr| self.create_expr(expr, left.schema()))
.collect::<Result<Vec<_>, _>>()?;
let right_join_exprs: Vec<_> = join
.right_join_keys
.iter()
.map(|expr| self.create_expr(expr, right.schema()))
.collect::<Result<Vec<_>, _>>()?;

let join_on = left_join_exprs
.into_iter()
.zip(right_join_exprs)
.collect::<Vec<_>>();

let join_type = match join.join_type.try_into() {
Ok(JoinType::Inner) => DFJoinType::Inner,
Ok(JoinType::LeftOuter) => DFJoinType::Left,
Ok(JoinType::RightOuter) => DFJoinType::Right,
Ok(JoinType::FullOuter) => DFJoinType::Full,
Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
Err(_) => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported join type: {:?}",
join.join_type
)));
}
};

// Handle join filter as DataFusion `JoinFilter` struct
let join_filter = if let Some(expr) = &join.condition {
let left_schema = left.schema();
let right_schema = right.schema();
let left_fields = left_schema.fields();
let right_fields = right_schema.fields();
let all_fields: Vec<_> = left_fields
.into_iter()
.chain(right_fields)
.cloned()
.collect();
let full_schema = Arc::new(Schema::new(all_fields));

let physical_expr = self.create_expr(expr, full_schema)?;
let (left_field_indices, right_field_indices) = expr_to_columns(
&physical_expr,
left.schema().fields.len(),
right.schema().fields.len(),
)?;
let column_indices = JoinFilter::build_column_indices(
left_field_indices.clone(),
right_field_indices.clone(),
);

let filter_fields: Vec<Field> = left_field_indices
.clone()
.into_iter()
.map(|i| left.schema().field(i).clone())
.chain(
right_field_indices
.clone()
.into_iter()
.map(|i| right.schema().field(i).clone()),
)
.collect_vec();

let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());

// Rewrite the physical expression to use the new column indices.
// DataFusion's join filter is bound to an intermediate schema that contains
// only the fields used in the filter expression, while Spark's join filter
// expression is bound to the full schema, so the column indices must be
// remapped.
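// For example (illustrative): with a 4-column left schema and a 3-column right
// schema, a filter referencing full-schema columns 1 (left) and 5 (right) gives
// left_field_indices = [1] and right_field_indices = [1]; the intermediate
// schema has two fields, so the expression is rewritten to reference indices 0 and 1.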
let rewritten_physical_expr = rewrite_physical_expr(
physical_expr,
left_schema.fields.len(),
right_schema.fields.len(),
&left_field_indices,
&right_field_indices,
)?;

Some(JoinFilter::new(
rewritten_physical_expr,
column_indices,
filter_schema,
))
} else {
None
};

// DataFusion's `HashJoinExec` operator buffers input batches internally. We need
// to copy the input batches to avoid data corruption when an upstream operator
// reuses a batch.
let left = if op_reuse_array(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};

let right = if op_reuse_array(&right) {
Arc::new(CopyExec::new(right))
} else {
right
};

let join = Arc::new(HashJoinExec::try_new(
left,
right,
join_on,
join_filter,
&join_type,
PartitionMode::Partitioned,
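// `null_equals_null` is false: null join keys never compare equal.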
false,
)?);

Ok((left_scans, join))
}
}
}

@@ -1026,6 +1157,143 @@ impl From<ExpressionError> for DataFusionError {
}
}

/// Returns true if the given operator can return an input array as its output array
/// without modification. This is used to determine whether the input batch must be
/// copied to avoid data corruption from batch reuse.
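/// For instance, `ScanExec` may reuse the arrays backing its output batches, so a
/// buffering operator such as `HashJoinExec` must receive a copy.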
fn op_reuse_array(op: &Arc<dyn ExecutionPlan>) -> bool {
op.as_any().downcast_ref::<ScanExec>().is_some()
|| op.as_any().downcast_ref::<LocalLimitExec>().is_some()
|| op.as_any().downcast_ref::<ProjectionExec>().is_some()
|| op.as_any().downcast_ref::<FilterExec>().is_some()
}

/// Collects the indices of the columns in the input schema that are used in the expression
/// and returns them as a pair of vectors, one for the left side and one for the right side.
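/// For example (illustrative), with left_field_len = 4 and right_field_len = 3, an
/// expression referencing combined-schema columns 1 and 5 yields ([1], [1]).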
fn expr_to_columns(
expr: &Arc<dyn PhysicalExpr>,
left_field_len: usize,
right_field_len: usize,
) -> Result<(Vec<usize>, Vec<usize>), ExecutionError> {
let mut left_field_indices: Vec<usize> = vec![];
let mut right_field_indices: Vec<usize> = vec![];

expr.apply(&mut |expr| {
Ok({
if let Some(column) = expr.as_any().downcast_ref::<Column>() {
if column.index() >= left_field_len + right_field_len {
return Err(DataFusionError::Internal(format!(
"Column index {} out of range",
column.index()
)));
} else if column.index() < left_field_len {
left_field_indices.push(column.index());
} else {
right_field_indices.push(column.index() - left_field_len);
}
}
VisitRecursion::Continue
})
})?;

left_field_indices.sort();
right_field_indices.sort();

Ok((left_field_indices, right_field_indices))
}

/// A physical join filter rewriter that maps column indices bound to the full join
/// schema onto the indices of the intermediate filter schema. See `rewrite_physical_expr`.
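/// For example (illustrative), with left_field_len = 4, left_field_indices = [1], and
/// right_field_indices = [1], a left column at index 1 is rewritten to index 0, and a
/// right column at index 5 to index 1 (position 0 within the right indices, offset by
/// the single left filter field).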
struct JoinFilterRewriter<'a> {
left_field_len: usize,
right_field_len: usize,
left_field_indices: &'a [usize],
right_field_indices: &'a [usize],
}

impl JoinFilterRewriter<'_> {
fn new<'a>(
left_field_len: usize,
right_field_len: usize,
left_field_indices: &'a [usize],
right_field_indices: &'a [usize],
) -> JoinFilterRewriter<'a> {
JoinFilterRewriter {
left_field_len,
right_field_len,
left_field_indices,
right_field_indices,
}
}
}

impl TreeNodeRewriter for JoinFilterRewriter<'_> {
type N = Arc<dyn PhysicalExpr>;

fn mutate(&mut self, node: Self::N) -> datafusion_common::Result<Self::N> {
let new_expr: Arc<dyn PhysicalExpr> =
if let Some(column) = node.as_any().downcast_ref::<Column>() {
if column.index() < self.left_field_len {
// left side
let new_index = self
.left_field_indices
.iter()
.position(|&x| x == column.index())
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Column index {} not found in left field indices",
column.index()
))
})?;
Arc::new(Column::new(column.name(), new_index))
} else if column.index() < self.left_field_len + self.right_field_len {
// right side
let new_index = self
.right_field_indices
.iter()
.position(|&x| x + self.left_field_len == column.index())
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Column index {} not found in right field indices",
column.index()
))
})?;
Arc::new(Column::new(
column.name(),
new_index + self.left_field_indices.len(),
))
} else {
return Err(DataFusionError::Internal(format!(
"Column index {} out of range",
column.index()
)));
}
} else {
node.clone()
};
Ok(new_expr)
}
}

/// Rewrites the physical expression to use the new column indices.
/// This is necessary when the physical expression is used in a join filter, as the column
/// indices are different from the original schema.
fn rewrite_physical_expr(
expr: Arc<dyn PhysicalExpr>,
left_field_len: usize,
right_field_len: usize,
left_field_indices: &[usize],
right_field_indices: &[usize],
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let mut rewriter = JoinFilterRewriter::new(
left_field_len,
right_field_len,
left_field_indices,
right_field_indices,
);

Ok(expr.rewrite(&mut rewriter)?)
}

#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
2 changes: 1 addition & 1 deletion core/src/execution/operators/copy.rs
@@ -91,7 +91,7 @@ impl ExecutionPlan for CopyExec {
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
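// Return this node's direct child; `ExecutionPlan::children` must not skip a
// level (returning `self.input.children()` would yield the grandchildren).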
self.input.children()
vec![self.input.clone()]
}

fn with_new_children(
19 changes: 19 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -40,6 +40,7 @@ message Operator {
Limit limit = 105;
ShuffleWriter shuffle_writer = 106;
Expand expand = 107;
HashJoin hash_join = 108;
}
}

@@ -87,3 +88,21 @@ message Expand {
repeated spark.spark_expression.Expr project_list = 1;
int32 num_expr_per_project = 3;
}

message HashJoin {

Can we name it Join so that we can use it for both SMJ and SHJ / BHJ? I will further use it in the BNLJ change I am working on.

Member Author:
They are different join operators; I'm not sure how we would use the same Join to represent them.

@singhpk234 (Mar 15, 2024):
Apologies, I wasn't clear in the comment above; I was thinking of something like this:

message Join {
  repeated spark.spark_expression.Expr left_join_keys = 1;
  repeated spark.spark_expression.Expr right_join_keys = 2;
  JoinType join_type = 3;
  // can serve as condition in SHJ and sort_options in SMJ
  repeated spark.spark_expression.Expr join_exprs = 4;
  JoinExec join_exec = 5;
}

enum JoinExec {
  HashJoin = 0;
  SMJ = 1;
  ...
}

Maybe it's too much, and having a different proto message for each join may be the right thing to do!

Member Author:
Yeah, it looks more complicated to me; one message per join operator seems simpler.

repeated spark.spark_expression.Expr left_join_keys = 1;
repeated spark.spark_expression.Expr right_join_keys = 2;
JoinType join_type = 3;
optional spark.spark_expression.Expr condition = 4;
}

enum JoinType {
Inner = 0;
LeftOuter = 1;
RightOuter = 2;
FullOuter = 3;
LeftSemi = 4;
RightSemi = 5;
LeftAnti = 6;
RightAnti = 7;
}