
Commit ab2dcaa

fix: Fallback to Spark for unsupported input besides ordering (apache#768)
1 parent ffb96c3 commit ab2dcaa

File tree

2 files changed: +35 -12 lines changed


spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 6 additions & 3 deletions
@@ -2926,7 +2926,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case HashPartitioning(expressions, _) =>
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType))
+            expressions.forall(e => supportedDataType(e.dataType)) &&
+            inputs.forall(attr => supportedDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
@@ -2936,7 +2937,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case RangePartitioning(orderings, _) =>
         val supported =
           orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            orderings.forall(e => supportedDataType(e.dataType))
+            orderings.forall(e => supportedDataType(e.dataType)) &&
+            inputs.forall(attr => supportedDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $orderings"
         }
@@ -2975,7 +2977,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case HashPartitioning(expressions, _) =>
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType))
+            expressions.forall(e => supportedDataType(e.dataType)) &&
+            inputs.forall(attr => supportedDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
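
In effect, a HashPartitioning or RangePartitioning now stays in Comet only when every partitioning expression serializes to protobuf, every expression's data type is supported, and, newly, every input attribute's data type is supported; otherwise the exchange falls back to Spark. A minimal standalone sketch of the combined predicate, using toy stand-ins for the Catalyst types (DataType, Attribute, Expr, supportedDataType, and exprToProto below are simplified assumptions, not the real Comet APIs):

// Toy sketch of the combined support check; not the real Comet code.
object SupportCheckSketch {
  sealed trait DataType
  case object IntegerType extends DataType
  case object NullType extends DataType

  final case class Attribute(name: String, dataType: DataType)
  final case class Expr(dataType: DataType)

  // Assumption: only IntegerType counts as supported in this toy model.
  def supportedDataType(dt: DataType): Boolean = dt == IntegerType

  // Stand-in for QueryPlanSerde.exprToProto: succeeds only for supported types.
  def exprToProto(e: Expr, inputs: Seq[Attribute]): Option[String] =
    if (supportedDataType(e.dataType)) Some("proto") else None

  // The check after this commit: serialization, expression types, AND all
  // input attribute types must pass before Comet handles the partitioning.
  def partitioningSupported(expressions: Seq[Expr], inputs: Seq[Attribute]): Boolean =
    expressions.map(exprToProto(_, inputs)).forall(_.isDefined) &&
      expressions.forall(e => supportedDataType(e.dataType)) &&
      inputs.forall(attr => supportedDataType(attr.dataType))

  def main(args: Array[String]): Unit = {
    val exprs = Seq(Expr(IntegerType))
    // A NullType column that no partitioning expression references would
    // previously slip through; the new inputs check catches it.
    val inputs = Seq(Attribute("index", IntegerType), Attribute("col", NullType))
    println(partitioningSupported(exprs, inputs)) // false => fall back to Spark
  }
}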

spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala

Lines changed: 29 additions & 9 deletions
@@ -19,12 +19,14 @@
 
 package org.apache.comet.exec
 
+import scala.util.Random
+
 import org.scalactic.source.Position
 import org.scalatest.Tag
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.{Partitioner, SparkConf}
-import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
+import org.apache.spark.sql.{CometTestBase, DataFrame, RandomDataGenerator, Row}
 import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
@@ -68,17 +70,35 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
 
   test("Unsupported types for SinglePartition should fallback to Spark") {
     checkSparkAnswer(spark.sql("""
-        |SELECT
-        | AVG(null),
-        | COUNT(null),
-        | FIRST(null),
-        | LAST(null),
-        | MAX(null),
-        | MIN(null),
-        | SUM(null)
+      |SELECT
+      | AVG(null),
+      | COUNT(null),
+      | FIRST(null),
+      | LAST(null),
+      | MAX(null),
+      | MIN(null),
+      | SUM(null)
     """.stripMargin))
   }
 
+  test("Fallback to Spark for unsupported input besides ordering") {
+    val dataGenerator = RandomDataGenerator
+      .forType(
+        dataType = NullType,
+        nullable = true,
+        new Random(System.nanoTime()),
+        validJulianDatetime = false)
+      .get
+
+    val schema = new StructType()
+      .add("index", IntegerType, nullable = false)
+      .add("col", NullType, nullable = true)
+    val rdd =
+      spark.sparkContext.parallelize((1 to 20).map(i => Row(i, dataGenerator())))
+    val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
+    checkSparkAnswer(df)
+  }
+
   test("Disable Comet shuffle with AQE coalesce partitions enabled") {
     Seq(true, false).foreach { coalescePartitionsEnabled =>
       withSQLConf(
