feat: Support HashJoin
viirya committed Mar 13, 2024
1 parent 81a641f commit 8078466
Showing 7 changed files with 327 additions and 8 deletions.
153 changes: 150 additions & 3 deletions core/src/execution/datafusion/planner.rs
@@ -17,7 +17,7 @@

//! Converts Spark physical plan to DataFusion physical plan

use std::{str::FromStr, sync::Arc};
use std::{collections::HashMap, str::FromStr, sync::Arc};

use arrow_schema::{DataType, Field, Schema, TimeUnit};
use datafusion::{
@@ -37,13 +37,17 @@ use datafusion::{
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
joins::{utils::JoinFilter, HashJoinExec, PartitionMode},
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
ExecutionPlan, Partitioning,
},
};
use datafusion_common::ScalarValue;
use datafusion_common::{
tree_node::{TreeNode, VisitRecursion},
JoinType as DFJoinType, ScalarValue,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
@@ -76,7 +80,7 @@ use crate::{
agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
ScalarFunc,
},
spark_operator::{operator::OpStruct, Operator},
spark_operator::{operator::OpStruct, JoinType, Operator},
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
},
};
@@ -858,6 +862,107 @@ impl PhysicalPlanner {
Arc::new(CometExpandExec::new(projections, child, schema)),
))
}
OpStruct::HashJoin(join) => {
assert_eq!(children.len(), 2);
let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;

left_scans.append(&mut right_scans);

let left_join_exprs: Vec<_> = join
.left_join_keys
.iter()
.map(|expr| self.create_expr(expr, left.schema()))
.collect::<Result<Vec<_>, _>>()?;
let right_join_exprs: Vec<_> = join
.right_join_keys
.iter()
.map(|expr| self.create_expr(expr, right.schema()))
.collect::<Result<Vec<_>, _>>()?;

let join_on = left_join_exprs
.into_iter()
.zip(right_join_exprs)
.collect::<Vec<_>>();

let join_type = match join.join_type.try_into() {
Ok(JoinType::Inner) => DFJoinType::Inner,
Ok(JoinType::LeftOuter) => DFJoinType::Left,
Ok(JoinType::RightOuter) => DFJoinType::Right,
Ok(JoinType::FullOuter) => DFJoinType::Full,
Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
Err(_) => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported join type: {:?}",
join.join_type
)));
}
};

// Convert the join condition into DataFusion's `JoinFilter` struct
let join_filter = if let Some(expr) = &join.condition {
let physical_expr = self.create_expr(expr, left.schema())?;
let (left_field_indices, right_field_indices) = expr_to_columns(
&physical_expr,
left.schema().fields.len(),
right.schema().fields.len(),
)?;
let column_indices = JoinFilter::build_column_indices(
left_field_indices.clone(),
right_field_indices.clone(),
);

let filter_fields: Vec<Field> = left_field_indices
.into_iter()
.map(|i| left.schema().field(i).clone())
.chain(
right_field_indices
.into_iter()
.map(|i| right.schema().field(i).clone()),
)
.collect_vec();

let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());

Some(JoinFilter::new(
physical_expr,
column_indices,
filter_schema,
))
} else {
None
};

// The DataFusion `HashJoinExec` operator keeps input batches internally. We need
// to copy each input batch to avoid data corruption when the input batch's
// buffers are reused.
let left = if !is_op_do_copying(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};

let right = if !is_op_do_copying(&right) {
Arc::new(CopyExec::new(right))
} else {
right
};

let join = Arc::new(HashJoinExec::try_new(
left,
right,
join_on,
join_filter,
&join_type,
PartitionMode::Partitioned,
false,
)?);

Ok((left_scans, join))
}
}
}

@@ -1026,6 +1131,48 @@ impl From<ExpressionError> for DataFusionError {
}
}

/// Returns true if the given operator copies its input batch, avoiding data
/// corruption from input array reuse.
fn is_op_do_copying(op: &Arc<dyn ExecutionPlan>) -> bool {
op.as_any().downcast_ref::<CopyExec>().is_some()
|| op.as_any().downcast_ref::<CometExpandExec>().is_some()
|| op.as_any().downcast_ref::<SortExec>().is_some()
}

/// Collects the indices of the columns in the input schema that are used in the expression
/// and returns them as a pair of vectors, one for the left side and one for the right side.
fn expr_to_columns(
expr: &Arc<dyn PhysicalExpr>,
left_field_len: usize,
right_field_len: usize,
) -> Result<(Vec<usize>, Vec<usize>), ExecutionError> {
let mut left_field_indices: Vec<usize> = vec![];
let mut right_field_indices: Vec<usize> = vec![];

expr.apply(&mut |expr| {
Ok({
if let Some(column) = expr.as_any().downcast_ref::<Column>() {
if column.index() >= left_field_len + right_field_len {
return Err(DataFusionError::Internal(format!(
"Column index {} out of range",
column.index()
)));
} else if column.index() < left_field_len {
left_field_indices.push(column.index());
} else {
right_field_indices.push(column.index() - left_field_len);
}
}
VisitRecursion::Continue
})
})?;

left_field_indices.sort();
right_field_indices.sort();

Ok((left_field_indices, right_field_indices))
}

#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
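A note on the new `expr_to_columns` helper: the join condition is created against the concatenated left-plus-right schema, so the column indices it references are combined indices. The helper splits them into left-relative and right-relative indices, which `JoinFilter::build_column_indices` then uses to build the filter's intermediate schema. Below is a minimal standalone sketch of that splitting logic (simplified, without DataFusion types; names are illustrative):

```rust
// Sketch (assumed, simplified): given the combined column indices a join
// condition references, split them into left- and right-relative indices,
// the same way `expr_to_columns` does. `left_field_len` is the arity of
// the left child's schema.
fn split_column_indices(
    referenced: &[usize],
    left_field_len: usize,
    right_field_len: usize,
) -> Result<(Vec<usize>, Vec<usize>), String> {
    let mut left = vec![];
    let mut right = vec![];
    for &i in referenced {
        if i >= left_field_len + right_field_len {
            return Err(format!("Column index {} out of range", i));
        } else if i < left_field_len {
            left.push(i);
        } else {
            // Right-side indices are rebased against the right schema.
            right.push(i - left_field_len);
        }
    }
    left.sort();
    right.sort();
    Ok((left, right))
}

fn main() {
    // Left schema has 3 fields, right has 2; the condition references
    // combined indices 1 (a left column) and 3 (a right column).
    let (l, r) = split_column_indices(&[1, 3], 3, 2).unwrap();
    assert_eq!(l, vec![1]); // second field of the left side
    assert_eq!(r, vec![0]); // first field of the right side
    println!("left: {:?}, right: {:?}", l, r);
}
```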
2 changes: 1 addition & 1 deletion core/src/execution/operators/copy.rs
@@ -91,7 +91,7 @@ impl ExecutionPlan for CopyExec {
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
self.input.children()
vec![self.input.clone()]
}

fn with_new_children(
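The one-line `CopyExec` fix matters for plan rewrites: `children()` must report an operator's direct children, not its input's children, or any transformation that rebuilds the tree via `with_new_children(plan.children())` would silently drop `CopyExec`'s input node. A minimal sketch of that contract (simplified; not the actual DataFusion trait):

```rust
// Sketch (assumed, simplified): an operator reports its *direct* children.
// Returning `self.input.children()` would skip one tree level.
use std::sync::Arc;

trait Plan {
    fn children(&self) -> Vec<Arc<dyn Plan>>;
}

struct Leaf;
impl Plan for Leaf {
    fn children(&self) -> Vec<Arc<dyn Plan>> {
        vec![]
    }
}

struct Copy {
    input: Arc<dyn Plan>,
}
impl Plan for Copy {
    fn children(&self) -> Vec<Arc<dyn Plan>> {
        // Correct: report the direct child, not the grandchildren.
        vec![self.input.clone()]
    }
}

fn main() {
    let copy = Copy { input: Arc::new(Leaf) };
    // One direct child; the old `self.input.children()` would have returned 0.
    assert_eq!(copy.children().len(), 1);
}
```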
19 changes: 19 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -40,6 +40,7 @@ message Operator {
Limit limit = 105;
ShuffleWriter shuffle_writer = 106;
Expand expand = 107;
HashJoin hash_join = 108;
}
}

@@ -87,3 +88,21 @@ message Expand {
repeated spark.spark_expression.Expr project_list = 1;
int32 num_expr_per_project = 3;
}

message HashJoin {
repeated spark.spark_expression.Expr left_join_keys = 1;
repeated spark.spark_expression.Expr right_join_keys = 2;
JoinType join_type = 3;
optional spark.spark_expression.Expr condition = 4;
}

enum JoinType {
Inner = 0;
LeftOuter = 1;
RightOuter = 2;
FullOuter = 3;
LeftSemi = 4;
RightSemi = 5;
LeftAnti = 6;
RightAnti = 7;
}
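On the Rust side, prost-generated code stores enum fields such as `join_type` as a raw `i32`, which is why the planner above matches on `join.join_type.try_into()`. A minimal sketch of that decoding pattern (a hand-written stand-in for the generated code; only two variants shown):

```rust
// Sketch (assumed): prost keeps proto enum fields as i32 and generates a
// TryFrom<i32> impl, so out-of-range wire values are rejected at decode time.
#[derive(Debug, PartialEq)]
enum JoinType {
    Inner = 0,
    LeftOuter = 1,
}

impl TryFrom<i32> for JoinType {
    type Error = ();
    fn try_from(v: i32) -> Result<Self, ()> {
        match v {
            0 => Ok(JoinType::Inner),
            1 => Ok(JoinType::LeftOuter),
            _ => Err(()), // unknown value: reported as "Unsupported join type"
        }
    }
}

fn main() {
    assert_eq!(JoinType::try_from(0), Ok(JoinType::Inner));
    assert!(JoinType::try_from(99).is_err());
}
```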
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -337,6 +338,27 @@ class CometSparkSessionExtensions
op
}

case op: ShuffledHashJoinExec
if isCometOperatorEnabled(conf, "hash_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometHashJoinExec(
nativeOp,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.buildSide,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case c @ CoalesceExec(numPartitions, child)
if isCometOperatorEnabled(conf, "coalesce")
&& isCometNative(child) =>
@@ -576,7 +598,9 @@ object CometSparkSessionExtensions extends Logging {

private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = {
val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled"
conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled"
conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) &&
!conf.getConfString(operatorDisabledFlag, "false").toBoolean
}

private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {
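The `isCometOperatorEnabled` change adds a per-operator `.disabled` flag alongside `.enabled`. Because `&&` binds tighter than `||` in Scala, the `.disabled` flag only vetoes the "all operators enabled" path; an explicit per-operator `.enabled=true` still wins. A sketch of that truth table (assuming `COMET_EXEC_CONFIG_PREFIX` is `spark.comet.exec`):

```rust
// Sketch (assumed, simplified): the enable/disable precedence introduced above.
// `&&` binds tighter than `||`, so this reads: enabled || (all_enabled && !disabled).
fn is_comet_operator_enabled(enabled: bool, all_enabled: bool, disabled: bool) -> bool {
    enabled || all_enabled && !disabled
}

fn main() {
    // spark.comet.exec.all.enabled=true but spark.comet.exec.hash_join.disabled=true:
    // the operator is off.
    assert!(!is_comet_operator_enabled(false, true, true));
    // An explicit spark.comet.exec.hash_join.enabled=true still turns it on.
    assert!(is_comet_operator_enabled(true, true, true));
}
```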
48 changes: 46 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -25,7 +25,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
@@ -35,14 +36,15 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus}
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
import org.apache.comet.shims.ShimQueryPlanSerde

/**
@@ -1838,6 +1840,48 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
}
}

case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") =>
if (join.buildSide == BuildRight) {
// DataFusion HashJoin assumes build side is always left.
// TODO: support BuildRight
return None
}

val condition = join.condition.map { cond =>
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
if (condProto.isEmpty) {
return None
}
condProto.get
}

val joinType = join.joinType match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ => return None // other join types are not supported yet
}

val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))

if (leftKeys.forall(_.isDefined) &&
rightKeys.forall(_.isDefined) &&
childOp.nonEmpty) {
val joinBuilder = OperatorOuterClass.HashJoin
.newBuilder()
.setJoinType(joinType)
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
condition.foreach(joinBuilder.setCondition)
Some(result.setHashJoin(joinBuilder).build())
} else {
None
}

case op if isCometSink(op) =>
// These operators are source of Comet native execution chain
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
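The serde side follows an all-or-nothing rule: the HashJoin operator is emitted only when every join key converts to protobuf and the optional condition, if present, converts too; any conversion failure returns `None`, so Spark falls back to its own `ShuffledHashJoinExec`. A simplified sketch of that guard (illustrative types, not the actual serde API):

```rust
// Sketch (assumed, simplified): `Option<String>` stands in for "protobuf
// expression, or None on conversion failure".
fn serialize_join(
    left_keys: &[Option<String>],
    right_keys: &[Option<String>],
    condition: Option<Option<String>>,
) -> Option<Vec<String>> {
    // A join condition that fails to convert vetoes the whole operator,
    // mirroring the early `return None` in the Scala code.
    let condition = match condition {
        Some(None) => return None,
        Some(Some(c)) => Some(c),
        None => None,
    };
    // Every join key must convert, mirroring `leftKeys.forall(_.isDefined)`.
    if !left_keys.iter().chain(right_keys).all(|k| k.is_some()) {
        return None;
    }
    let mut out: Vec<String> = left_keys
        .iter()
        .chain(right_keys)
        .map(|k| k.clone().unwrap())
        .collect();
    out.extend(condition);
    Some(out)
}

fn main() {
    // All keys convert: the join is emitted.
    assert!(serialize_join(&[Some("a".into())], &[Some("b".into())], None).is_some());
    // One key fails to convert: fall back to Spark.
    assert!(serialize_join(&[None], &[Some("b".into())], None).is_none());
}
```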
