
Commit 2e0f00a

comphead and andygrove authored
feat: Reenable tests for filtered SMJ anti join (#1211)
* feat: reenable filtered SMJ Anti join tests
* Add CoalesceBatchesExec around SMJ with join filter
* adding `CoalesceBatches`

Co-authored-by: Andy Grove <agrove@apache.org>
1 parent 9320aed commit 2e0f00a
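The core of the change: when the native sort-merge join carries a join filter, the planner now wraps it in DataFusion's CoalesceBatchesExec so that the many tiny batches the filter produces are re-chunked up to the configured batch size. As a minimal sketch (the helper name `coalesce` is hypothetical; the actual wiring is in planner.rs below), the wrapper amounts to:

// A minimal sketch, not the Comet source itself: `coalesce` is a
// hypothetical helper showing how DataFusion's CoalesceBatchesExec
// wraps another operator. Rows and schema are unchanged; small input
// batches are buffered and re-emitted at roughly `target_batch_size`
// rows per RecordBatch.
use std::sync::Arc;

use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::ExecutionPlan;

fn coalesce(
    input: Arc<dyn ExecutionPlan>,
    target_batch_size: usize,
) -> Arc<dyn ExecutionPlan> {
    // CoalesceBatchesExec::new takes the child plan and a target row
    // count per output batch; it is a transparent re-chunking operator.
    Arc::new(CoalesceBatchesExec::new(input, target_batch_size))
}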

File tree: 4 files changed, +43 -24 lines changed

native/core/src/execution/planner.rs

Lines changed: 37 additions & 11 deletions
@@ -70,6 +70,7 @@ use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctio
 
 use crate::execution::shuffle::CompressionCodec;
 use crate::execution::spark_plan::SparkPlan;
+use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion_comet_proto::{
     spark_expression::{
         self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr,

@@ -1183,17 +1184,42 @@ impl PhysicalPlanner {
                     false,
                 )?);
 
-                Ok((
-                    scans,
-                    Arc::new(SparkPlan::new(
-                        spark_plan.plan_id,
-                        join,
-                        vec![
-                            Arc::clone(&join_params.left),
-                            Arc::clone(&join_params.right),
-                        ],
-                    )),
-                ))
+                if join.filter.is_some() {
+                    // SMJ with join filter produces lots of tiny batches
+                    let coalesce_batches: Arc<dyn ExecutionPlan> =
+                        Arc::new(CoalesceBatchesExec::new(
+                            Arc::<SortMergeJoinExec>::clone(&join),
+                            self.session_ctx
+                                .state()
+                                .config_options()
+                                .execution
+                                .batch_size,
+                        ));
+                    Ok((
+                        scans,
+                        Arc::new(SparkPlan::new_with_additional(
+                            spark_plan.plan_id,
+                            coalesce_batches,
+                            vec![
+                                Arc::clone(&join_params.left),
+                                Arc::clone(&join_params.right),
+                            ],
+                            vec![join],
+                        )),
+                    ))
+                } else {
+                    Ok((
+                        scans,
+                        Arc::new(SparkPlan::new(
+                            spark_plan.plan_id,
+                            join,
+                            vec![
+                                Arc::clone(&join_params.left),
+                                Arc::clone(&join_params.right),
+                            ],
+                        )),
+                    ))
+                }
             }
             OpStruct::HashJoin(join) => {
                 let (join_params, scans) = self.parse_join_parameters(
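Note the switch to SparkPlan::new_with_additional in the filtered branch: coalesce_batches becomes the root of the native plan that Spark sees, while the wrapped SortMergeJoinExec is carried along in the additional-plans list (vec![join]), presumably so that the join's own metrics remain reachable even though it is no longer the root operator.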

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 1 addition & 1 deletion
@@ -567,7 +567,7 @@ class CometSparkSessionExtensions
 
       case op: SortMergeJoinExec
           if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
-            op.children.forall(isCometNative(_)) =>
+            op.children.forall(isCometNative) =>
         val newOp = transform1(op)
         newOp match {
           case Some(nativeOp) =>

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 5 deletions
@@ -2859,11 +2859,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case RightOuter => JoinType.RightOuter
       case FullOuter => JoinType.FullOuter
       case LeftSemi => JoinType.LeftSemi
-      // TODO: DF SMJ with join condition fails TPCH q21
-      case LeftAnti if condition.isEmpty => JoinType.LeftAnti
-      case LeftAnti =>
-        withInfo(join, "LeftAnti SMJ join with condition is not supported")
-        return None
+      case LeftAnti => JoinType.LeftAnti
       case _ =>
         // Spark doesn't support other join types
         withInfo(op, s"Unsupported join type ${join.joinType}")

spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala

Lines changed: 4 additions & 7 deletions
@@ -216,7 +216,7 @@ class CometJoinSuite extends CometTestBase {
         v.toDouble,
         v.toString,
         v % 2 == 0,
-        v.toString().getBytes,
+        v.toString.getBytes,
         Decimal(v))
 
     withParquetTable((0 until 10).map(i => manyTypes(i, i % 5)), "tbl_a") {

@@ -294,6 +294,7 @@ class CometJoinSuite extends CometTestBase {
 
   test("SortMergeJoin without join filter") {
     withSQLConf(
+      CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.key -> "true",
       SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {

@@ -338,9 +339,9 @@ class CometJoinSuite extends CometTestBase {
     }
   }
 
-  // https://github.com/apache/datafusion-comet/issues/398
-  ignore("SortMergeJoin with join filter") {
+  test("SortMergeJoin with join filter") {
     withSQLConf(
+      CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.key -> "true",
       CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
       SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {

@@ -391,9 +392,6 @@ class CometJoinSuite extends CometTestBase {
         "AND tbl_a._2 >= tbl_b._1")
       checkSparkAnswerAndOperator(df9)
 
-      // TODO: Enable these tests after fixing the issue:
-      // https://github.com/apache/datafusion-comet/issues/861
-      /*
       val df10 = sql(
         "SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b ON tbl_a._2 = tbl_b._1 " +
           "AND tbl_a._2 >= tbl_b._1")

@@ -403,7 +401,6 @@ class CometJoinSuite extends CometTestBase {
         "SELECT * FROM tbl_b LEFT ANTI JOIN tbl_a ON tbl_a._2 = tbl_b._1 " +
           "AND tbl_a._2 >= tbl_b._1")
       checkSparkAnswerAndOperator(df11)
-      */
     }
   }
 }
