
Commit e3fdff6: Add CometColumnarToRowExec

1 parent: 7c4a8a0

File tree: 4 files changed, +208 −4 lines

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 5 additions & 0 deletions
@@ -1072,7 +1072,10 @@ class CometSparkSessionExtensions
   case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] {
     override def apply(plan: SparkPlan): SparkPlan = {
       val eliminatedPlan = plan transformUp {
+        case ColumnarToRowExec(child) => CometColumnarToRowExec(child)
         case ColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child
+        case CometColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) =>
+          sparkToColumnar.child
         case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child
         // Spark adds `RowToColumnar` under Comet columnar shuffle. But it's redundant as the
         // shuffle takes row-based input.

@@ -1089,6 +1092,8 @@ class CometSparkSessionExtensions
       eliminatedPlan match {
         case ColumnarToRowExec(child: CometCollectLimitExec) =>
           child
+        case CometColumnarToRowExec(child: CometCollectLimitExec) =>
+          child
         case other =>
           other
       }
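To see what the amended rule does, here is a minimal, self-contained sketch of the transformUp rewrite pattern it relies on. The toy Plan ADT, the transformUp helper, and the node names below are illustrative stand-ins rather than Comet's actual classes; like Catalyst, the helper applies the partial function bottom-up and at most once per node, which is why the ordering of the cases matters.

// Toy stand-ins for the physical-plan nodes (illustrative only).
object EliminateRedundantTransitionsSketch extends App {
  sealed trait Plan
  case class ColumnarToRow(child: Plan) extends Plan
  case class CometColumnarToRow(child: Plan) extends Plan
  case class CometSparkToColumnar(child: Plan) extends Plan
  case class Scan(table: String) extends Plan

  // Bottom-up rewrite: transform children first, then try the rule once on
  // the result, mirroring Catalyst's `transformUp`.
  def transformUp(plan: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
    val withNewChildren = plan match {
      case ColumnarToRow(c) => ColumnarToRow(transformUp(c)(rule))
      case CometColumnarToRow(c) => CometColumnarToRow(transformUp(c)(rule))
      case CometSparkToColumnar(c) => CometSparkToColumnar(transformUp(c)(rule))
      case leaf: Scan => leaf
    }
    rule.applyOrElse(withNewChildren, identity[Plan])
  }

  val rule: PartialFunction[Plan, Plan] = {
    // Swap Spark's transition for the Comet copy that carries the fix.
    case ColumnarToRow(child) => CometColumnarToRow(child)
    // A columnar-to-row sitting directly on a row-to-columnar is a no-op
    // round trip; keep the original child instead.
    case CometColumnarToRow(CometSparkToColumnar(child)) => child
  }

  // Spark's transition is replaced wholesale:
  assert(transformUp(ColumnarToRow(Scan("t")))(rule) == CometColumnarToRow(Scan("t")))
  // ...and the redundant transition pair collapses to its input:
  assert(transformUp(CometColumnarToRow(CometSparkToColumnar(Scan("t"))))(rule) == Scan("t"))
  println("plan rewrites behave as expected")
}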
spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.util.Utils
+
+/**
+ * Copied from Spark `ColumnarToRowExec`. Comet needs the fix for SPARK-50235 but cannot wait for
+ * the fix to be released in Spark versions. We copy the implementation here to apply the fix.
+ */
+case class CometColumnarToRowExec(child: SparkPlan)
+    extends ColumnarToRowTransition
+    with CodegenSupport {
+  // supportsColumnar requires to be only called on driver side, see also SPARK-37779.
+  assert(Utils.isInRunningSparkTask || child.supportsColumnar)
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  // `ColumnarToRowExec` processes the input RDD directly, which is kind of a leaf node in the
+  // codegen stage and needs to do the limit check.
+  protected override def canCheckLimitNotReached: Boolean = true
+
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"))
+
+  override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+    val numInputBatches = longMetric("numInputBatches")
+    // This avoids calling `output` in the RDD closure, so that we don't need to include the entire
+    // plan (this) in the closure.
+    val localOutput = this.output
+    child.executeColumnar().mapPartitionsInternal { batches =>
+      val toUnsafe = UnsafeProjection.create(localOutput, localOutput)
+      batches.flatMap { batch =>
+        numInputBatches += 1
+        numOutputRows += batch.numRows()
+        batch.rowIterator().asScala.map(toUnsafe)
+      }
+    }
+  }
+
+  /**
+   * Generate [[ColumnVector]] expressions for our parent to consume as rows. This is called once
+   * per [[ColumnVector]] in the batch.
+   */
+  private def genCodeColumnVector(
+      ctx: CodegenContext,
+      columnVar: String,
+      ordinal: String,
+      dataType: DataType,
+      nullable: Boolean): ExprCode = {
+    val javaType = CodeGenerator.javaType(dataType)
+    val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
+    val isNullVar = if (nullable) {
+      JavaCode.isNullVariable(ctx.freshName("isNull"))
+    } else {
+      FalseLiteral
+    }
+    val valueVar = ctx.freshName("value")
+    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
+    val code = code"${ctx.registerComment(str)}" + (if (nullable) {
+      code"""
+        boolean $isNullVar = $columnVar.isNullAt($ordinal);
+        $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
+      """
+    } else {
+      code"$javaType $valueVar = $value;"
+    })
+    ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
+  }
+
+  /**
+   * Produce code to process the input iterator as [[ColumnarBatch]]es. This produces an
+   * [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] for each row in each batch.
+   */
+  override protected def doProduce(ctx: CodegenContext): String = {
+    // PhysicalRDD always just has one input
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
+
+    // metrics
+    val numOutputRows = metricTerm(ctx, "numOutputRows")
+    val numInputBatches = metricTerm(ctx, "numInputBatches")
+
+    val columnarBatchClz = classOf[ColumnarBatch].getName
+    val batch = ctx.addMutableState(columnarBatchClz, "batch")
+
+    val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
+    val columnVectorClzs =
+      child.vectorTypes.getOrElse(Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
+    val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
+      case (columnVectorClz, i) =>
+        val name = ctx.addMutableState(columnVectorClz, s"colInstance$i")
+        (name, s"$name = ($columnVectorClz) $batch.column($i);")
+    }.unzip
+
+    val nextBatch = ctx.freshName("nextBatch")
+    val nextBatchFuncName = ctx.addNewFunction(
+      nextBatch,
+      s"""
+         |private void $nextBatch() throws java.io.IOException {
+         |  if ($input.hasNext()) {
+         |    $batch = ($columnarBatchClz)$input.next();
+         |    $numInputBatches.add(1);
+         |    $numOutputRows.add($batch.numRows());
+         |    $idx = 0;
+         |    ${columnAssigns.mkString("", "\n", "\n")}
+         |  }
+         |}""".stripMargin)
+
+    ctx.currentVars = null
+    val rowidx = ctx.freshName("rowIdx")
+    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
+      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
+    }
+    val localIdx = ctx.freshName("localIdx")
+    val localEnd = ctx.freshName("localEnd")
+    val numRows = ctx.freshName("numRows")
+    val shouldStop = if (parent.needStopCheck) {
+      s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
+    } else {
+      "// shouldStop check is eliminated"
+    }
+
+    val writableColumnVectorClz = classOf[WritableColumnVector].getName
+
+    s"""
+       |if ($batch == null) {
+       |  $nextBatchFuncName();
+       |}
+       |while ($limitNotReachedCond $batch != null) {
+       |  int $numRows = $batch.numRows();
+       |  int $localEnd = $numRows - $idx;
+       |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+       |    int $rowidx = $idx + $localIdx;
+       |    ${consume(ctx, columnsBatchInput).trim}
+       |    $shouldStop
+       |  }
+       |  $idx = $numRows;
+       |
+       |  // Comet fix for SPARK-50235
+       |  for (int i = 0; i < ${colVars.length}; i++) {
+       |    if (!($batch.column(i) instanceof $writableColumnVectorClz)) {
+       |      $batch.column(i).close();
+       |    }
+       |  }
+       |
+       |  $batch = null;
+       |  $nextBatchFuncName();
+       |}
+       |// Comet fix for SPARK-50235: clean up resources
+       |if ($batch != null) {
+       |  $batch.close();
+       |}
+     """.stripMargin
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    Seq(child.executeColumnar().asInstanceOf[RDD[InternalRow]]) // Hack because of type erasure
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CometColumnarToRowExec =
+    copy(child = newChild)
+}
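The two "Comet fix for SPARK-50235" blocks are the reason this class exists: as soon as a batch has been fully consumed (and again on cleanup), every column vector that is not a WritableColumnVector, i.e. one whose memory Spark's codegen does not reuse, is closed eagerly instead of being held until the task finishes. Below is a minimal sketch of the same pattern in interpreted (non-codegen) form, assuming Spark's ColumnarBatch and WritableColumnVector are on the classpath; drainBatch is a hypothetical helper, not part of Comet or Spark.

import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.vectorized.ColumnarBatch

object EagerBatchRelease {
  // Hypothetical helper illustrating the eager-release pattern in the
  // generated loop above; not part of Comet or Spark.
  def drainBatch(batch: ColumnarBatch)(consumeRow: Int => Unit): Unit = {
    var row = 0
    while (row < batch.numRows()) {
      consumeRow(row)
      row += 1
    }
    // SPARK-50235: close vectors whose memory is not reused by codegen
    // (anything that is not a WritableColumnVector) as soon as the batch
    // is drained, rather than waiting for task completion.
    var col = 0
    while (col < batch.numCols()) {
      batch.column(col) match {
        case _: WritableColumnVector => // writable vectors are reused across batches
        case other => other.close()
      }
      col += 1
    }
  }
}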

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -28,8 +28,8 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
-import org.apache.spark.sql.comet.CometProjectExec
-import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, ProjectExec, WholeStageCodegenExec}
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec}
+import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf

@@ -752,7 +752,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     val project = cometPlan
       .asInstanceOf[WholeStageCodegenExec]
       .child
-      .asInstanceOf[ColumnarToRowExec]
+      .asInstanceOf[CometColumnarToRowExec]
       .child
       .asInstanceOf[InputAdapter]
       .child
spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 2 additions & 1 deletion
@@ -36,7 +36,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
-import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec}
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometColumnarToRowExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExtendedMode, InputAdapter, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper

@@ -174,6 +174,7 @@ abstract class CometTestBase
     wrapped.foreach {
       case _: CometScanExec | _: CometBatchScanExec =>
       case _: CometSinkPlaceHolder | _: CometScanWrapper =>
+      case _: CometColumnarToRowExec =>
       case _: CometSparkToColumnarExec =>
       case _: CometExec | _: CometShuffleExchangeExec =>
       case _: CometBroadcastExchangeExec =>
