Close orc fetchOrcStatement and remove result save file when ExecuteStatement close
lsm1 committed Dec 7, 2023
1 parent 42634a1 commit 80e1f0d
Showing 4 changed files with 61 additions and 19 deletions.
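For orientation, here is a minimal sketch of where a statement's result lands on the filesystem and what the new cleanup has to remove when the operation closes. The identifiers mirror the diff below; the helper itself is illustrative, not part of the commit.

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

object ResultFileCleanupSketch {
  // Results are written as zstd-compressed ORC under
  //   <savePath>/<engineId>/<sessionId>/<statementId>/part-*.orc
  // so closing the operation must delete that whole directory recursively.
  def cleanup(
      savePath: String,
      engineId: String,
      sessionId: String,
      statementId: String,
      hadoopConf: Configuration): Unit = {
    val dir = new Path(s"$savePath/$engineId/$sessionId/$statementId")
    val fs = dir.getFileSystem(hadoopConf)
    if (fs.exists(dir)) {
      fs.delete(dir, true) // recursive delete, mirroring the new ExecuteStatement.close()
    }
  }
}
```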
@@ -22,6 +22,7 @@ import java.util.concurrent.RejectedExecutionException
import scala.Array._
import scala.collection.JavaConverters._

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.kyuubi.SparkDatasetHelper._
import org.apache.spark.sql.types._
@@ -47,6 +48,8 @@ class ExecuteStatement(
override def getOperationLog: Option[OperationLog] = Option(operationLog)
override protected def supportProgress: Boolean = true

private var fetchOrcStatement: Option[FetchOrcStatement] = None
private var saveFileName: Option[String] = None
override protected def resultSchema: StructType = {
if (result == null || result.schema.isEmpty) {
new StructType().add("Result", "string")
@@ -65,6 +68,15 @@ class ExecuteStatement(
OperationLog.removeCurrentOperationLog()
}

override def close(): Unit = {
super.close()
fetchOrcStatement.foreach(_.close())
saveFileName.foreach { p =>
val path = new Path(p)
path.getFileSystem(spark.sparkContext.hadoopConfiguration).delete(path, true)
}
}

protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
resultDF.toLocalIterator().asScala
}
@@ -164,17 +176,18 @@ class ExecuteStatement(
if (hasResultSet && sparkSave && shouldSaveResultToHdfs(resultMaxRows, threshold, result)) {
val sessionId = session.handle.identifier.toString
val savePath = session.sessionManager.getConf.get(OPERATION_RESULT_SAVE_TO_FILE_PATH)
val fileName = s"$savePath/$engineId/$sessionId/$statementId"
saveFileName = Some(s"$savePath/$engineId/$sessionId/$statementId")
val colName = range(0, result.schema.size).map(x => "col" + x)
if (resultMaxRows > 0) {
result.toDF(colName: _*).limit(resultMaxRows).write
.option("compression", "zstd").format("orc").save(fileName)
.option("compression", "zstd").format("orc").save(saveFileName.get)
} else {
result.toDF(colName: _*).write
.option("compression", "zstd").format("orc").save(fileName)
.option("compression", "zstd").format("orc").save(saveFileName.get)
}
info(s"Save result to $fileName")
return new FetchOrcStatement(spark).getIterator(fileName, resultSchema)
info(s"Save result to $saveFileName")
fetchOrcStatement = Some(new FetchOrcStatement(spark))
return fetchOrcStatement.get.getIterator(saveFileName.get, resultSchema)
}
val internalArray = if (resultMaxRows <= 0) {
info("Execute in full collect mode")
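Because the writer renames the columns to col0, col1, ... before saving, the directory can also be read back with stock Spark for a quick sanity check. This is only an illustration; the engine streams the rows through FetchOrcStatement (the changes that follow) rather than re-reading a DataFrame.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Illustration only: `spark` and `savedDir` are assumed to be in scope;
// the saved files carry the generated names col0, col1, ..., so no schema
// needs to be supplied here.
def readSavedResult(spark: SparkSession, savedDir: String): DataFrame =
  spark.read.orc(savedDir)
```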
@@ -36,9 +36,13 @@ import org.apache.spark.sql.execution.datasources.orc.OrcDeserializer
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.KyuubiException
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.SPARK_ENGINE_RUNTIME_VERSION
import org.apache.kyuubi.operation.{FetchIterator, IterableFetchIterator}
import org.apache.kyuubi.util.reflect.DynConstructors

class FetchOrcStatement(spark: SparkSession) {

var orcIter: OrcFileIterator = _
def getIterator(path: String, orcSchema: StructType): FetchIterator[Row] = {
val conf = spark.sparkContext.hadoopConfiguration
val savePath = new Path(path)
@@ -59,21 +63,42 @@ class FetchOrcStatement(spark: SparkSession) {
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
val deserializer = getOrcDeserializer(orcSchema, colId)
val iter = new OrcFileIterator(list)
val iterRow = iter.map(value =>
orcIter = new OrcFileIterator(list)
val iterRow = orcIter.map(value =>
unsafeProjection(deserializer.deserialize(value)))
.map(value => toRowConverter(value))
new IterableFetchIterator[Row](iterRow.toIterable)
}

def close(): Unit = {
orcIter.close()
}

private def getOrcDeserializer(orcSchema: StructType, colId: Array[Int]): OrcDeserializer = {
try {
val cls = Class.forName("org.apache.spark.sql.execution.datasources.orc.OrcDeserializer")
val constructor = cls.getDeclaredConstructors.apply(0)
if (constructor.getParameterCount == 3) {
constructor.newInstance(new StructType, orcSchema, colId).asInstanceOf[OrcDeserializer]
if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
// https://issues.apache.org/jira/browse/SPARK-34535
DynConstructors.builder()
.impl(
classOf[OrcDeserializer],
classOf[StructType],
classOf[Array[Int]])
.build[OrcDeserializer]()
.newInstance(
orcSchema,
colId)
} else {
constructor.newInstance(orcSchema, colId).asInstanceOf[OrcDeserializer]
DynConstructors.builder()
.impl(
classOf[OrcDeserializer],
classOf[StructType],
classOf[StructType],
classOf[Array[Int]])
.build[OrcDeserializer]()
.newInstance(
new StructType,
orcSchema,
colId)
}
} catch {
case e: Throwable =>
@@ -84,7 +109,7 @@ class FetchOrcStatement(spark: SparkSession) {

class OrcFileIterator(fileList: ListBuffer[LocatedFileStatus]) extends Iterator[OrcStruct] {

val iters = fileList.map(x => getOrcFileIterator(x))
private val iters = fileList.map(x => getOrcFileIterator(x))

var idx = 0

@@ -106,6 +131,10 @@ class OrcFileIterator(fileList: ListBuffer[LocatedFileStatus]) extends Iterator[
}
}

def close(): Unit = {
iters.foreach(_.close())
}

private def getOrcFileIterator(file: LocatedFileStatus): RecordReaderIterator[OrcStruct] = {
val orcRecordReader = {
val split =
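A hypothetical caller-side view of the lifecycle the new close() completes: the iterator keeps ORC record readers open on the saved files, so it has to be closed before ExecuteStatement.close() can delete the directory. Here `spark`, `savedDir` and `resultSchema` are assumed to be in scope.

```scala
// Hypothetical usage sketch, not code from the commit.
val fetch = new FetchOrcStatement(spark)
try {
  val rows = fetch.getIterator(savedDir, resultSchema)
  // ... rows are served to the client through the operation's fetch calls ...
} finally {
  fetch.close() // closes every underlying RecordReaderIterator via OrcFileIterator.close()
}
```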
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SortExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.HiveResult
@@ -297,8 +298,10 @@ object SparkDatasetHelper extends Logging {
}

def shouldSaveResultToHdfs(resultMaxRows: Int, threshold: Int, result: DataFrame): Boolean = {
if (isCommandExec(result.queryExecution.executedPlan.nodeName)) {
return false
}
lazy val limit = result.queryExecution.executedPlan match {
case plan if isCommandExec(plan.nodeName) => 0
case collectLimit: CollectLimitExec => collectLimit.limit
case _ => resultMaxRows
}
@@ -308,17 +311,13 @@ object SparkDatasetHelper extends Logging {
} else {
result.queryExecution.optimizedPlan.stats.sizeInBytes
}
lazy val isSort = result.queryExecution.sparkPlan match {
case s: SortExec => s.global
case _ => false
}
lazy val colSize =
if (result == null || result.schema.isEmpty) {
0
} else {
result.schema.size
}
threshold > 0 && colSize > 0 && !isSort && stats >= threshold
threshold > 0 && colSize > 0 && stats >= threshold
}

private def isCommandExec(nodeName: String): Boolean = {
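Read as a plain predicate, the reworked gate reduces to the following. The plan-dependent inputs (command detection and the limit-aware size estimate) are collapsed into parameters here, so this is a restatement for clarity rather than the method itself.

```scala
// Simplified restatement of shouldSaveResultToHdfs: commands never spill to
// files, and saving only kicks in for a non-empty schema once the estimated
// result size reaches the configured threshold.
def shouldSave(
    isCommand: Boolean,
    estimatedBytes: BigInt,
    columnCount: Int,
    threshold: Int): Boolean = {
  if (isCommand) {
    return false
  }
  threshold > 0 && columnCount > 0 && estimatedBytes >= BigInt(threshold)
}
```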
@@ -1960,6 +1960,7 @@ object KyuubiConf {
.doc("The threshold of Spark result save to hdfs file, default value is 200 MB")
.version("1.9.0")
.intConf
.checkValue(_ > 0, "must be positive value")
.createWithDefault(209715200)

val OPERATION_INCREMENTAL_COLLECT: ConfigEntry[Boolean] =
