Skip to content
Merged
48 changes: 37 additions & 11 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctio

use crate::execution::shuffle::CompressionCodec;
use crate::execution::spark_plan::SparkPlan;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion_comet_proto::{
spark_expression::{
self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr,
Expand Down Expand Up @@ -1183,17 +1184,42 @@ impl PhysicalPlanner {
false,
)?);

Ok((
scans,
Arc::new(SparkPlan::new(
spark_plan.plan_id,
join,
vec![
Arc::clone(&join_params.left),
Arc::clone(&join_params.right),
],
)),
))
if join.filter.is_some() {
// SMJ with join filter produces lots of tiny batches
let coalesce_batches: Arc<dyn ExecutionPlan> =
Arc::new(CoalesceBatchesExec::new(
Arc::<SortMergeJoinExec>::clone(&join),
self.session_ctx
.state()
.config_options()
.execution
.batch_size,
));
Ok((
scans,
Arc::new(SparkPlan::new_with_additional(
spark_plan.plan_id,
coalesce_batches,
vec![
Arc::clone(&join_params.left),
Arc::clone(&join_params.right),
],
vec![join],
)),
))
} else {
Ok((
scans,
Arc::new(SparkPlan::new(
spark_plan.plan_id,
join,
vec![
Arc::clone(&join_params.left),
Arc::clone(&join_params.right),
],
)),
))
}
}
OpStruct::HashJoin(join) => {
let (join_params, scans) = self.parse_join_parameters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ class CometSparkSessionExtensions

case op: SortMergeJoinExec
if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
op.children.forall(isCometNative(_)) =>
op.children.forall(isCometNative) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2859,11 +2859,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
// TODO: DF SMJ with join condition fails TPCH q21
case LeftAnti if condition.isEmpty => JoinType.LeftAnti
case LeftAnti =>
withInfo(join, "LeftAnti SMJ join with condition is not supported")
return None
case LeftAnti => JoinType.LeftAnti
case _ =>
// Spark doesn't support other join types
withInfo(op, s"Unsupported join type ${join.joinType}")
Expand Down
11 changes: 4 additions & 7 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class CometJoinSuite extends CometTestBase {
v.toDouble,
v.toString,
v % 2 == 0,
v.toString().getBytes,
v.toString.getBytes,
Decimal(v))

withParquetTable((0 until 10).map(i => manyTypes(i, i % 5)), "tbl_a") {
Expand Down Expand Up @@ -294,6 +294,7 @@ class CometJoinSuite extends CometTestBase {

test("SortMergeJoin without join filter") {
withSQLConf(
CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.key -> "true",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
Expand Down Expand Up @@ -338,9 +339,9 @@ class CometJoinSuite extends CometTestBase {
}
}

// https://github.com/apache/datafusion-comet/issues/398
ignore("SortMergeJoin with join filter") {
test("SortMergeJoin with join filter") {
withSQLConf(
CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.key -> "true",
CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Expand Down Expand Up @@ -391,9 +392,6 @@ class CometJoinSuite extends CometTestBase {
"AND tbl_a._2 >= tbl_b._1")
checkSparkAnswerAndOperator(df9)

// TODO: Enable these tests after fixing the issue:
// https://github.com/apache/datafusion-comet/issues/861
/*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not just remove the comment. This test is also ignored. We need to re-enable it too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, thanks @viirya

val df10 = sql(
"SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b ON tbl_a._2 = tbl_b._1 " +
"AND tbl_a._2 >= tbl_b._1")
Expand All @@ -403,7 +401,6 @@ class CometJoinSuite extends CometTestBase {
"SELECT * FROM tbl_b LEFT ANTI JOIN tbl_a ON tbl_a._2 = tbl_b._1 " +
"AND tbl_a._2 >= tbl_b._1")
checkSparkAnswerAndOperator(df11)
*/
}
}
}
Expand Down
Loading