driver slice last batch
cfmcgrady committed Apr 4, 2023
1 parent a584943 commit 8593d85
Showing 3 changed files with 145 additions and 6 deletions.
@@ -20,14 +20,15 @@ package org.apache.kyuubi.engine.spark.operation
import java.util.concurrent.RejectedExecutionException

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution}
import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution, TakeOrderedAndProjectExec}
import org.apache.spark.sql.kyuubi.SparkDatasetHelper
import org.apache.spark.sql.types._
import org.apache.kyuubi.{KyuubiSQLException, Logging}
import org.apache.spark.sql.execution.arrow.ArrowCollectLimitExec
import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils}

import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
@@ -210,10 +211,35 @@ class ArrowBasedExecuteStatement(
df.queryExecution.executedPlan.resetMetrics()
df.queryExecution.executedPlan match {
case collectLimit @ CollectLimitExec(limit, _) =>
// CollectLimitExec: build Arrow batches on the driver, then trim them to the limit
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
ArrowCollectLimitExec.takeAsArrowBatches(collectLimit, df.schema, 1000, 1024 * 1024, timeZoneId)
val batches = ArrowCollectLimitExec.takeAsArrowBatches(collectLimit, df.schema, 1000, 1024 * 1024, timeZoneId)
val result = ArrayBuffer[Array[Byte]]()
var i = 0
var rest = limit
println(s"batch....size... ${batches.length}")
while (i < batches.length && rest > 0) {
val (batch, size) = batches(i)
if (size < rest) {
result += batch
// size < rest <= Int.MaxValue in this branch, so narrowing the Long row count is safe
rest = rest - size.toInt
} else if (size == rest) {
result += batch
rest = 0
} else { // size > rest
// this batch crosses the limit boundary, so slice it down to the remaining rows
result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest)
rest = 0
}
i += 1
}
result.toArray

case takeOrderedAndProjectExec @ TakeOrderedAndProjectExec(limit, _, _, _) =>
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
ArrowCollectLimitExec.taskOrdered(takeOrderedAndProjectExec, df.schema, 1000, 1024 * 1024, timeZoneId)
.map(_._1)
case _ =>
println("yyyy")
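The loop above keeps whole Arrow batches while they still fit under the LIMIT and slices only the batch that crosses the boundary. Below is a minimal standalone sketch of that accumulation step (not part of the commit): the batch payloads are reduced to plain strings, and the hypothetical sliceLast callback stands in for KyuubiArrowUtils.sliceV2.

    import scala.collection.mutable.ArrayBuffer

    object SliceLastBatchSketch {
      // Each element is (payload, rowCount). Whole batches are kept while they still fit
      // under the remaining limit; the batch that crosses the boundary is sliced.
      def takeUpToLimit[T](
          batches: Seq[(T, Long)],
          limit: Int,
          sliceLast: (T, Int) => T): Seq[T] = {
        val result = ArrayBuffer[T]()
        var rest = limit
        var i = 0
        while (i < batches.length && rest > 0) {
          val (batch, rows) = batches(i)
          if (rows <= rest) {
            result += batch
            rest -= rows.toInt
          } else {
            result += sliceLast(batch, rest)
            rest = 0
          }
          i += 1
        }
        result.toSeq
      }

      def main(args: Array[String]): Unit = {
        // three 4-row batches with LIMIT 10: expect 4 + 4 + 2 rows back
        val batches = Seq(("b1", 4L), ("b2", 4L), ("b3", 4L))
        val taken = takeUpToLimit[String](batches, 10, (b, n) => s"$b[0, $n)")
        println(taken.mkString(", ")) // b1, b2, b3[0, 2)
      }
    }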
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.execution.CollectLimitExec
import org.apache.spark.sql.execution.{CollectLimitExec, TakeOrderedAndProjectExec}
import org.apache.spark.sql.types.StructType

object ArrowCollectLimitExec extends SQLConfHelper {
@@ -124,4 +124,21 @@ object ArrowCollectLimitExec extends SQLConfHelper {
buf.toArray
}
}

// Collect the ordered-and-limited TakeOrderedAndProjectExec result on the driver
// and convert it into Arrow batches paired with each batch's row count.
def taskOrdered(
takeOrdered: TakeOrderedAndProjectExec,
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
timeZoneId: String): Array[Batch] = {
val batches = ArrowConvertersHelper.toBatchWithSchemaIterator(
takeOrdered.executeCollect().iterator,
schema,
maxRecordsPerBatch,
maxEstimatedBatchSize,
takeOrdered.limit,
timeZoneId)
batches.map(b => b -> batches.rowCountInLastBatch).toArray
}
}
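One subtlety in taskOrdered: pairing each batch with batches.rowCountInLastBatch only works because map over an Iterator is lazy, so the counter is read immediately after the corresponding batch is produced. A toy illustration of that pattern (the CountingIterator class and its sizes are invented for this sketch):

    object LazyIteratorPairing {
      // Toy stand-in for the Arrow batch iterator: a counter field is updated as each
      // element is produced, mirroring rowCountInLastBatch.
      class CountingIterator(sizes: Seq[Int]) extends Iterator[String] {
        var rowCountInLastBatch: Int = 0
        private val it = sizes.iterator
        override def hasNext: Boolean = it.hasNext
        override def next(): String = {
          val n = it.next()
          rowCountInLastBatch = n
          s"batch($n rows)"
        }
      }

      def main(args: Array[String]): Unit = {
        val batches = new CountingIterator(Seq(4, 4, 2))
        // map over an Iterator is lazy, so the counter is read right after each next()
        // and every batch is paired with its own row count, not the final one
        val paired = batches.map(b => b -> batches.rowCountInLastBatch).toArray
        paired.foreach(println)
        // (batch(4 rows),4)
        // (batch(4 rows),4)
        // (batch(2 rows),2)
      }
    }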

@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.arrow

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object KyuubiArrowUtils {
// child allocator used to deserialize and re-serialize Arrow batches on the driver
val rootAllocator = new RootAllocator(Long.MaxValue)
.newChildAllocator("KyuubiArrowUtils", 0, Long.MaxValue)
// Slice an Arrow IPC stream (schema plus a single batch) down to rows [start, start + length).
def slice(bytes: Array[Byte], start: Int, length: Int): Array[Byte] = {
val in = new ByteArrayInputStream(bytes)
val out = new ByteArrayOutputStream()

var reader: ArrowStreamReader = null
try {
reader = new ArrowStreamReader(in, rootAllocator)
reader.loadNextBatch()
// slice the decoded root and write it back out as a new Arrow stream
val root = reader.getVectorSchemaRoot.slice(start, length)
val writer = new ArrowStreamWriter(root, null, out)
writer.start()
writer.writeBatch()
writer.end()
writer.close()
out.toByteArray
} finally {
if (reader != null) {
reader.close()
}
in.close()
out.close()
}
}

// Slice a serialized Arrow record batch message down to rows [start, start + length),
// rebuilding the vectors from `schema` because the payload carries no schema header.
def sliceV2(
schema: StructType,
timeZoneId: String,
bytes: Array[Byte],
start: Int,
length: Int): Array[Byte] = {
val in = new ByteArrayInputStream(bytes)
val out = new ByteArrayOutputStream()
var root: VectorSchemaRoot = null
var sliced: VectorSchemaRoot = null

try {
// the payload is a bare IPC record batch message, so it is deserialized directly
// rather than read through ArrowStreamReader
val recordBatch = MessageSerializer.deserializeRecordBatch(
new ReadChannel(Channels.newChannel(in)), rootAllocator)
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)

root = VectorSchemaRoot.create(arrowSchema, rootAllocator)
val vectorLoader = new VectorLoader(root)
vectorLoader.load(recordBatch)
recordBatch.close()

sliced = root.slice(start, length)
val unloader = new VectorUnloader(sliced)
val writeChannel = new WriteChannel(Channels.newChannel(out))
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)
batch.close()
out.toByteArray()
} finally {
if (sliced != null) {
sliced.close()
}
if (root != null) {
root.close()
}
in.close()
out.close()
}
}
}
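For reference, a self-contained sketch (not part of the commit) of the deserialize, load, slice, re-serialize round trip that sliceV2 performs, using a single nullable Int column. The schema, row values, and object name are invented for illustration, and only stock Arrow Java APIs are used.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import java.nio.channels.Channels

    import scala.collection.JavaConverters._

    import org.apache.arrow.memory.RootAllocator
    import org.apache.arrow.vector.{IntVector, VectorLoader, VectorSchemaRoot, VectorUnloader}
    import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel}
    import org.apache.arrow.vector.ipc.message.MessageSerializer
    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

    object ArrowSliceRoundTrip {
      def main(args: Array[String]): Unit = {
        val allocator = new RootAllocator(Long.MaxValue)
        val schema = new Schema(Seq(Field.nullable("id", new ArrowType.Int(32, true))).asJava)

        // Build a 10-row batch and serialize it as a bare IPC record batch message,
        // the same shape of payload that sliceV2 receives.
        val source = VectorSchemaRoot.create(schema, allocator)
        val ids = source.getVector("id").asInstanceOf[IntVector]
        ids.allocateNew(10)
        (0 until 10).foreach(i => ids.setSafe(i, i))
        source.setRowCount(10)
        val out = new ByteArrayOutputStream()
        val batch = new VectorUnloader(source).getRecordBatch()
        MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch)
        batch.close()
        source.close()
        val bytes = out.toByteArray

        // Deserialize, load into a fresh root built from the schema, slice the first 3 rows.
        val in = new ByteArrayInputStream(bytes)
        val recordBatch = MessageSerializer.deserializeRecordBatch(
          new ReadChannel(Channels.newChannel(in)), allocator)
        val target = VectorSchemaRoot.create(schema, allocator)
        new VectorLoader(target).load(recordBatch)
        recordBatch.close()
        val sliced = target.slice(0, 3)
        println(s"rows after slice: ${sliced.getRowCount}") // 3

        sliced.close()
        target.close()
        allocator.close()
      }
    }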
