diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index ef2787f83..01c79e0c0 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -17,7 +17,7 @@
 
 //! Converts Spark physical plan to DataFusion physical plan
 
-use std::{str::FromStr, sync::Arc};
+use std::{collections::HashMap, str::FromStr, sync::Arc};
 
 use arrow_schema::{DataType, Field, Schema, TimeUnit};
 use datafusion::{
@@ -37,13 +37,17 @@ use datafusion::{
     physical_plan::{
         aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
         filter::FilterExec,
+        joins::{utils::JoinFilter, HashJoinExec, PartitionMode},
         limit::LocalLimitExec,
         projection::ProjectionExec,
         sorts::sort::SortExec,
         ExecutionPlan, Partitioning,
     },
 };
-use datafusion_common::ScalarValue;
+use datafusion_common::{
+    tree_node::{TreeNode, VisitRecursion},
+    JoinType as DFJoinType, ScalarValue,
+};
 use itertools::Itertools;
 use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
@@ -76,7 +80,7 @@ use crate::{
         agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
         ScalarFunc,
     },
-    spark_operator::{operator::OpStruct, Operator},
+    spark_operator::{operator::OpStruct, JoinType, Operator},
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
 },
 };
@@ -858,6 +862,107 @@ impl PhysicalPlanner {
                     Arc::new(CometExpandExec::new(projections, child, schema)),
                 ))
             }
+            OpStruct::HashJoin(join) => {
+                assert!(children.len() == 2);
+                let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
+                let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;
+
+                left_scans.append(&mut right_scans);
+
+                let left_join_exprs: Vec<_> = join
+                    .left_join_keys
+                    .iter()
+                    .map(|expr| self.create_expr(expr, left.schema()))
+                    .collect::<Result<Vec<_>, _>>()?;
+                let right_join_exprs: Vec<_> = join
+                    .right_join_keys
+                    .iter()
+                    .map(|expr| self.create_expr(expr, right.schema()))
+                    .collect::<Result<Vec<_>, _>>()?;
+
+                let join_on = left_join_exprs
+                    .into_iter()
+                    .zip(right_join_exprs)
+                    .collect::<Vec<_>>();
+
+                let join_type = match join.join_type.try_into() {
+                    Ok(JoinType::Inner) => DFJoinType::Inner,
+                    Ok(JoinType::LeftOuter) => DFJoinType::Left,
+                    Ok(JoinType::RightOuter) => DFJoinType::Right,
+                    Ok(JoinType::FullOuter) => DFJoinType::Full,
+                    Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
+                    Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
+                    Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
+                    Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
+                    Err(_) => {
+                        return Err(ExecutionError::GeneralError(format!(
+                            "Unsupported join type: {:?}",
+                            join.join_type
+                        )));
+                    }
+                };
+
+                // Handle the join filter as a DataFusion `JoinFilter` struct
+                let join_filter = if let Some(expr) = &join.condition {
+                    let physical_expr = self.create_expr(expr, left.schema())?;
+                    let (left_field_indices, right_field_indices) = expr_to_columns(
+                        &physical_expr,
+                        left.schema().fields.len(),
+                        right.schema().fields.len(),
+                    )?;
+                    let column_indices = JoinFilter::build_column_indices(
+                        left_field_indices.clone(),
+                        right_field_indices.clone(),
+                    );
+
+                    let filter_fields: Vec<Field> = left_field_indices
+                        .into_iter()
+                        .map(|i| left.schema().field(i).clone())
+                        .chain(
+                            right_field_indices
+                                .into_iter()
+                                .map(|i| right.schema().field(i).clone()),
+                        )
+                        .collect_vec();
+
+                    let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());
+
+                    Some(JoinFilter::new(
+                        physical_expr,
+                        column_indices,
+                        filter_schema,
+                    ))
+                } else {
+                    None
+                };
+
+                // DataFusion `HashJoinExec` operator keeps the input batch internally. We need
+                // to copy the input batch to avoid data corruption from reusing the input
+                // batch.
+                let left = if !is_op_do_copying(&left) {
+                    Arc::new(CopyExec::new(left))
+                } else {
+                    left
+                };
+
+                let right = if !is_op_do_copying(&right) {
+                    Arc::new(CopyExec::new(right))
+                } else {
+                    right
+                };
+
+                let join = Arc::new(HashJoinExec::try_new(
+                    left,
+                    right,
+                    join_on,
+                    join_filter,
+                    &join_type,
+                    PartitionMode::Partitioned,
+                    false,
+                )?);
+
+                Ok((left_scans, join))
+            }
         }
     }
 
@@ -1026,6 +1131,48 @@ impl From for DataFusionError {
     }
 }
 
+/// Returns true if the given operator copies the input batch, so that reusing input arrays
+/// cannot cause data corruption.
+fn is_op_do_copying(op: &Arc<dyn ExecutionPlan>) -> bool {
+    op.as_any().downcast_ref::<CopyExec>().is_some()
+        || op.as_any().downcast_ref::<CometExpandExec>().is_some()
+        || op.as_any().downcast_ref::<SortExec>().is_some()
+}
+
+/// Collects the indices of the columns in the input schema that are used in the expression
+/// and returns them as a pair of vectors, one for the left side and one for the right side.
+fn expr_to_columns(
+    expr: &Arc<dyn PhysicalExpr>,
+    left_field_len: usize,
+    right_field_len: usize,
+) -> Result<(Vec<usize>, Vec<usize>), ExecutionError> {
+    let mut left_field_indices: Vec<usize> = vec![];
+    let mut right_field_indices: Vec<usize> = vec![];
+
+    expr.apply(&mut |expr| {
+        Ok({
+            if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                if column.index() > left_field_len + right_field_len {
+                    return Err(DataFusionError::Internal(format!(
+                        "Column index {} out of range",
+                        column.index()
+                    )));
+                } else if column.index() < left_field_len {
+                    left_field_indices.push(column.index());
+                } else {
+                    right_field_indices.push(column.index() - left_field_len);
+                }
+            }
+            VisitRecursion::Continue
+        })
+    })?;
+
+    left_field_indices.sort();
+    right_field_indices.sort();
+
+    Ok((left_field_indices, right_field_indices))
+}
+
 #[cfg(test)]
 mod tests {
     use std::{sync::Arc, task::Poll};
diff --git a/core/src/execution/operators/copy.rs b/core/src/execution/operators/copy.rs
index 996db2b47..699ccf7ae 100644
--- a/core/src/execution/operators/copy.rs
+++ b/core/src/execution/operators/copy.rs
@@ -91,7 +91,7 @@ impl ExecutionPlan for CopyExec {
     }
 
     fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
-        self.input.children()
+        vec![self.input.clone()]
     }
 
     fn with_new_children(
diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto
index 5b07cb30b..ce58edd0f 100644
--- a/core/src/execution/proto/operator.proto
+++ b/core/src/execution/proto/operator.proto
@@ -40,6 +40,7 @@ message Operator {
     Limit limit = 105;
     ShuffleWriter shuffle_writer = 106;
     Expand expand = 107;
+    HashJoin hash_join = 108;
   }
 }
 
@@ -87,3 +88,21 @@ message Expand {
   repeated spark.spark_expression.Expr project_list = 1;
   int32 num_expr_per_project = 3;
 }
+
+message HashJoin {
+  repeated spark.spark_expression.Expr left_join_keys = 1;
+  repeated spark.spark_expression.Expr right_join_keys = 2;
+  JoinType join_type = 3;
+  optional spark.spark_expression.Expr condition = 4;
+}
+
+enum JoinType {
+  Inner = 0;
+  LeftOuter = 1;
+  RightOuter = 2;
+  FullOuter = 3;
+  LeftSemi = 4;
+  RightSemi = 5;
+  LeftAnti = 6;
+  RightAnti = 7;
+}
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 39c83ae53..7eca99557 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -337,6 +338,27 @@ class CometSparkSessionExtensions
               op
           }
 
+        case op: ShuffledHashJoinExec
+            if isCometOperatorEnabled(conf, "hash_join") &&
+              op.children.forall(isCometNative(_)) =>
+          val newOp = transform1(op)
+          newOp match {
+            case Some(nativeOp) =>
+              CometHashJoinExec(
+                nativeOp,
+                op,
+                op.leftKeys,
+                op.rightKeys,
+                op.joinType,
+                op.condition,
+                op.buildSide,
+                op.left,
+                op.right,
+                SerializedPlan(None))
+            case None =>
+              op
+          }
+
         case c @ CoalesceExec(numPartitions, child)
             if isCometOperatorEnabled(conf, "coalesce")
               && isCometNative(child) =>
@@ -576,7 +598,9 @@ object CometSparkSessionExtensions extends Logging {
 
   private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = {
     val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled"
-    conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
+    val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled"
+    conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) &&
+      !conf.getConfString(operatorDisabledFlag, "false").toBoolean
   }
 
   private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {
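
Illustration (not part of the patch): because Scala's `&&` binds tighter than `||`, the new check reads as `operatorFlag || (allEnabled && !operatorDisabledFlag)`, so an explicit per-operator `.enabled` flag still wins even when `.disabled` is set. A minimal sketch, assuming a `spark` session in scope and the usual `spark.comet.exec` value for `COMET_EXEC_CONFIG_PREFIX`:

    // Sketch only; the config prefix is an assumption, not taken from this diff.
    spark.conf.set("spark.comet.exec.all.enabled", "true")
    spark.conf.set("spark.comet.exec.hash_join.disabled", "true")
    // isCometOperatorEnabled(conf, "hash_join") is now false: all operators are on,
    // but hash_join is opted out through the new ".disabled" flag.

    spark.conf.set("spark.comet.exec.hash_join.enabled", "true")
    // Now it is true again: the per-operator ".enabled" flag short-circuits the check.
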
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 5da926e38..b0692ad8e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -25,7 +25,8 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
 import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
@@ -35,6 +36,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -42,7 +44,7 @@ import org.apache.spark.unsafe.types.UTF8String
 import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus}
 import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
 import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
-import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
+import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
 import org.apache.comet.shims.ShimQueryPlanSerde
 
 /**
@@ -1838,6 +1840,48 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         }
       }
 
+      case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") =>
+        if (join.buildSide == BuildRight) {
+          // DataFusion's HashJoin assumes the build side is always the left side.
+          // TODO: support BuildRight
+          return None
+        }
+
+        val condition = join.condition.map { cond =>
+          val condProto = exprToProto(cond, join.left.output ++ join.right.output)
+          if (condProto.isEmpty) {
+            return None
+          }
+          condProto.get
+        }
+
+        val joinType = join.joinType match {
+          case Inner => JoinType.Inner
+          case LeftOuter => JoinType.LeftOuter
+          case RightOuter => JoinType.RightOuter
+          case FullOuter => JoinType.FullOuter
+          case LeftSemi => JoinType.LeftSemi
+          case LeftAnti => JoinType.LeftAnti
+          case _ => return None // Spark doesn't support other join types
+        }
+
+        val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
+        val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
+
+        if (leftKeys.forall(_.isDefined) &&
+          rightKeys.forall(_.isDefined) &&
+          childOp.nonEmpty) {
+          val joinBuilder = OperatorOuterClass.HashJoin
+            .newBuilder()
+            .setJoinType(joinType)
+            .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
+            .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
+          condition.foreach(joinBuilder.setCondition)
+          Some(result.setHashJoin(joinBuilder).build())
+        } else {
+          None
+        }
+
       case op if isCometSink(op) =>
         // These operators are source of Comet native execution chain
         val scanBuilder = OperatorOuterClass.Scan.newBuilder()
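
Usage note (illustrative, not part of the patch): Spark usually builds the hash table on the side named in the `SHUFFLE_HASH` hint, so hinting the left-hand table tends to give `buildSide == BuildLeft`, which this patch can convert, while hinting the right-hand table tends to give `BuildRight`, which still falls back to Spark (see the TODO above). A sketch, with table names borrowed from the test below and behaviour that can vary by join type and statistics:

    // Typically becomes CometHashJoinExec (BuildLeft):
    val converted = spark.sql(
      "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")

    // Typically stays as Spark's ShuffledHashJoinExec (BuildRight not yet supported):
    val fallsBack = spark.sql(
      "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
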
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 5551ffdbc..2d03b2e65 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -31,9 +31,11 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
+import org.apache.spark.sql.catalyst.optimizer.BuildSide
+import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -324,6 +326,8 @@ abstract class CometNativeExec extends CometExec {
 
 abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode
 
+abstract class CometBinaryExec extends CometNativeExec with BinaryExecNode
+
 /**
  * Represents the serialized plan of Comet native operators. Only the first operator in a block of
  * continuous Comet native operators has defined plan bytes which contains the serialization of
@@ -584,6 +588,43 @@ case class CometHashAggregateExec(
     Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)
 }
 
+case class CometHashJoinExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    buildSide: BuildSide,
+    override val left: SparkPlan,
+    override val right: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometBinaryExec {
+  override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
+    this.copy(left = newLeft, right = newRight)
+
+  override def stringArgs: Iterator[Any] =
+    Iterator(leftKeys, rightKeys, joinType, condition, left, right)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometHashJoinExec =>
+        this.leftKeys == other.leftKeys &&
+          this.rightKeys == other.rightKeys &&
+          this.condition == other.condition &&
+          this.buildSide == other.buildSide &&
+          this.left == other.left &&
+          this.right == other.right &&
+          this.serializedPlanOpt == other.serializedPlanOpt
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(leftKeys, rightKeys, condition, left, right)
+}
+
 case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan)
     extends CometNativeExec
     with LeafExecNode {
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 6a34d4fe4..c21fd109e 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -58,6 +58,50 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("HashJoin without join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          val df1 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df1)
+
+          // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint.
+          // We need to investigate why this happens and fix it.
+          /*
+          val df2 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df2)
+
+          val df3 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b LEFT JOIN tbl_a ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df3)
+           */
+
+          val df4 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df4)
+
+          val df5 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b RIGHT JOIN tbl_a ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df5)
+
+          val df6 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df6)
+
+          val df7 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_b FULL JOIN tbl_a ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df7)
+        }
+      }
+    }
+  }
+
   test("Fix corrupted AggregateMode when transforming plan parameters") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
       val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
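
The test above only exercises equi-join keys; the planner changes also turn a non-equi `condition` into a DataFusion `JoinFilter`. A possible companion test for that path, sketched against the same helpers used in this suite (not part of this diff):

    test("HashJoin with join filter (illustrative sketch)") {
      withSQLConf(
        SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
        SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
        withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
          withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
            // Equi-join keys plus a non-equi predicate that spans both sides, which
            // exercises the JoinFilter path added in planner.rs.
            val df = sql(
              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
            checkSparkAnswerAndOperator(df)
          }
        }
      }
    }
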