Commit 639fa2f

minor: refactor decodeBatches to make private in broadcast exchange (#1195)
1 parent 053b7cc commit 639fa2f

File tree: 3 files changed, +25 −67 lines


spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 22 additions & 2 deletions
@@ -19,20 +19,24 @@
 
 package org.apache.spark.sql.comet
 
+import java.io.DataInputStream
+import java.nio.channels.Channels
 import java.util.UUID
 import java.util.concurrent.{Future, TimeoutException, TimeUnit}
 
 import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
-import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext}
 import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
@@ -299,7 +303,23 @@ class CometBatchRDD(
   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partition = split.asInstanceOf[CometBatchPartition]
     partition.value.value.toIterator
-      .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName))
+      .flatMap(decodeBatches(_, this.getClass.getSimpleName))
+  }
+
+  /**
+   * Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
+   */
+  private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      return Iterator.empty
+    }
+
+    // use Spark's compression codec (LZ4 by default) and not Comet's compression
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+    // batches are in Arrow IPC format
+    new ArrowReaderIterator(Channels.newChannel(ins), source)
   }
 }
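
For reference, the decode path that this commit makes private can be read in isolation. The sketch below is illustrative only: the object and method names (BroadcastDecodeSketch, decodeBroadcastPayload) are hypothetical, and it assumes a live SparkEnv plus Comet's ArrowReaderIterator, exactly as used in the hunk above.

import java.io.DataInputStream
import java.nio.channels.Channels

import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.ChunkedByteBuffer

object BroadcastDecodeSketch {
  // Hypothetical standalone version of the private decodeBatches above:
  // the broadcast payload is Arrow IPC data compressed with Spark's codec
  // (LZ4 by default), so decoding wraps the bytes in the codec's input
  // stream and hands the result to an Arrow reader.
  def decodeBroadcastPayload(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
    if (bytes.size == 0) return Iterator.empty
    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
    val ins = new DataInputStream(codec.compressedInputStream(bytes.toInputStream()))
    new ArrowReaderIterator(Channels.newChannel(ins), source)
  }
}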

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 3 additions & 32 deletions
@@ -19,22 +19,20 @@
 
 package org.apache.spark.sql.comet
 
-import java.io.{ByteArrayOutputStream, DataInputStream}
-import java.nio.channels.Channels
+import java.io.ByteArrayOutputStream
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.io.CompressionCodec
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
-import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
@@ -78,18 +76,6 @@ abstract class CometExec extends CometPlan {
   // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec.
   override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
 
-  /**
-   * Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
-   */
-  def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
-    val countsAndBytes = CometExec.getByteArrayRdd(this).collect()
-    val total = countsAndBytes.map(_._1).sum
-    val rows = countsAndBytes.iterator
-      .flatMap(countAndBytes =>
-        CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName))
-    (total, rows)
-  }
-
   protected def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
     sparkPlan.children.foreach(setSubqueries(planId, _))
 
@@ -161,21 +147,6 @@ object CometExec {
       Utils.serializeBatches(iter)
     }
   }
-
-  /**
-   * Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
-   */
-  def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
-    if (bytes.size == 0) {
-      return Iterator.empty
-    }
-
-    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-    val cbbis = bytes.toInputStream()
-    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
-
-    new ArrowReaderIterator(Channels.newChannel(ins), source)
-  }
 }
 
 /**

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 0 additions & 33 deletions
@@ -22,8 +22,6 @@ package org.apache.comet.exec
 import java.sql.Date
 import java.time.{Duration, Period}
 
-import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.util.Random
 
 import org.scalactic.source.Position
@@ -462,37 +460,6 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
-  test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") {
-    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
-    withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") {
-      withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
-        val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")
-
-        val nativeProject = find(df.queryExecution.executedPlan) {
-          case _: CometProjectExec => true
-          case _ => false
-        }.get.asInstanceOf[CometProjectExec]
-
-        val (rows, batches) = nativeProject.executeColumnarCollectIterator()
-        assert(rows == 46)
-
-        val column1 = mutable.ArrayBuffer.empty[Int]
-        val column2 = mutable.ArrayBuffer.empty[Int]
-
-        batches.foreach(batch => {
-          batch.rowIterator().asScala.foreach { row =>
-            assert(row.numFields == 2)
-            column1 += row.getInt(0)
-            column2 += row.getInt(1)
-          }
-        })
-
-        assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray)
-        assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray)
-      }
-    }
-  }
-
   test("scalar subquery") {
     val dataTypes =
       Seq(
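
The deleted test only exercised the removed executeColumnarCollectIterator API. If equivalent row-level coverage were wanted, the same data and query could be checked through the public DataFrame collect path instead; a minimal sketch, assuming it runs inside a CometTestBase suite with the same helpers (withSQLConf, withParquetTable, sql) used by the removed test:

// Hypothetical replacement check using only public DataFrame APIs:
// same table, query, and expected values as the removed test.
withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") {
  withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
    val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")
    val rows = df.collect()
    assert(rows.length == 46)
    assert(rows.map(_.getInt(0)).sorted === (4 until 50).map(_ + 1).toArray)
    assert(rows.map(_.getInt(1)).sorted === (5 until 51).map(_ + 2).toArray)
  }
}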
