[SPARK] Reuse time formatters instance in value serialization of TRowSet generation #5811

Status: Closed
File: RowSet.scala
@@ -24,15 +24,17 @@ import scala.collection.JavaConverters._
 import org.apache.hive.service.rpc.thrift._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.HiveResult
+import org.apache.spark.sql.execution.HiveResult.TimeFormatters
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.util.RowSetUtils._
 
 object RowSet {
 
-  def toHiveString(valueAndType: (Any, DataType), nested: Boolean = false): String = {
-    // compatible w/ Spark 3.1 and above
-    val timeFormatters = HiveResult.getTimeFormatters
The PR author (Contributor) commented here:

PS: the key change of this PR is skipping the recreation of the time formatters.
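To make the effect of that change concrete, here is a minimal, self-contained sketch of the pattern using plain java.time formatters rather than Spark's HiveResult internals (all names below are illustrative, not the PR's code). It contrasts building a formatter per value with hoisting one instance out of the loop, which is exactly the shape of the change in toRowBasedSet and toColumnBasedSet:

```scala
import java.time.{Instant, ZoneId}
import java.time.format.DateTimeFormatter

object FormatterReuseSketch extends App {
  // Constructing a formatter parses the pattern string and builds a printer,
  // which is expensive relative to a single format call.
  def freshFormatter(): DateTimeFormatter =
    DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneId.of("UTC"))

  val values = Seq.fill(100000)(Instant.now())

  // Before: a new formatter per value, the shape of the old toHiveString,
  // which ran HiveResult.getTimeFormatters on every invocation.
  val perValue = values.map(v => freshFormatter().format(v))

  // After: one formatter per batch, the shape of the new code, which calls
  // getTimeFormatters once per TRowSet and threads it through as a parameter.
  val shared = freshFormatter()
  val perBatch = values.map(v => shared.format(v))

  assert(perValue == perBatch) // same output, far less allocation and parsing
}
```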

cfmcgrady (Contributor) commented on Dec 6, 2023:

I think the reason the time formatters were created on each call is that the session time zone can be changed at any time by running `set spark.sql.session.timeZone=<new time zone>`.
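That concern is compatible with the PR's approach: the formatters are cached only for the duration of one TRowSet generation, and each new fetch calls HiveResult.getTimeFormatters again, picking up whatever session time zone is current at that moment. A hedged sketch of that reasoning, assuming getTimeFormatters reads the active session's spark.sql.session.timeZone (as it does in Spark 3.1+):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.HiveResult

object SessionTimeZoneSketch extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()

  spark.sql("set spark.sql.session.timeZone=UTC")
  // One fetch batch: formatters are captured once and shared by all values in it.
  val batch1Formatters = HiveResult.getTimeFormatters

  spark.sql("set spark.sql.session.timeZone=Asia/Shanghai")
  // The next batch re-captures, so the new zone is honored; a zone change can
  // never be observed halfway through a single TRowSet.
  val batch2Formatters = HiveResult.getTimeFormatters

  println(s"captured $batch1Formatters, then $batch2Formatters")
  spark.stop()
}
```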

+  def toHiveString(
+      valueAndType: (Any, DataType),
+      nested: Boolean = false,
+      timeFormatters: TimeFormatters): String = {
     HiveResult.toHiveString(valueAndType, nested, timeFormatters)
   }
 
@@ -71,14 +73,15 @@
   def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRows = new java.util.ArrayList[TRow](rowSize)
+    val timeFormatters = HiveResult.getTimeFormatters
     var i = 0
     while (i < rowSize) {
       val row = rows(i)
       val tRow = new TRow()
       var j = 0
       val columnSize = row.length
       while (j < columnSize) {
-        val columnValue = toTColumnValue(j, row, schema)
+        val columnValue = toTColumnValue(j, row, schema, timeFormatters)
         tRow.addToColVals(columnValue)
         j += 1
       }
@@ -91,18 +94,23 @@
   def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
+    val timeFormatters = HiveResult.getTimeFormatters
     var i = 0
     val columnSize = schema.length
     while (i < columnSize) {
       val field = schema(i)
-      val tColumn = toTColumn(rows, i, field.dataType)
+      val tColumn = toTColumn(rows, i, field.dataType, timeFormatters)
       tRowSet.addToColumns(tColumn)
       i += 1
     }
     tRowSet
   }
 
-  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
+  private def toTColumn(
+      rows: Seq[Row],
+      ordinal: Int,
+      typ: DataType,
+      timeFormatters: TimeFormatters): TColumn = {
     val nulls = new java.util.BitSet()
     typ match {
       case BooleanType =>
@@ -152,7 +160,7 @@
         while (i < rowSize) {
           val row = rows(i)
           nulls.set(i, row.isNullAt(ordinal))
-          values.add(toHiveString(row.get(ordinal) -> typ))
+          values.add(toHiveString(row.get(ordinal) -> typ, timeFormatters = timeFormatters))
           i += 1
         }
         TColumn.stringVal(new TStringColumn(values, nulls))
@@ -184,7 +192,8 @@
   private def toTColumnValue(
       ordinal: Int,
       row: Row,
-      types: StructType): TColumnValue = {
+      types: StructType,
+      timeFormatters: TimeFormatters): TColumnValue = {
     types(ordinal).dataType match {
       case BooleanType =>
         val boolValue = new TBoolValue
@@ -232,7 +241,9 @@
       case _ =>
         val tStrValue = new TStringValue
         if (!row.isNullAt(ordinal)) {
-          tStrValue.setValue(toHiveString(row.get(ordinal) -> types(ordinal).dataType))
+          tStrValue.setValue(toHiveString(
+            row.get(ordinal) -> types(ordinal).dataType,
+            timeFormatters = timeFormatters))
         }
         TColumnValue.stringVal(tStrValue)
     }
File: SparkDatasetHelper.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.HiveResult
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -105,17 +106,31 @@ object SparkDatasetHelper extends Logging {
     val quotedCol = (name: String) => col(quoteIfNeeded(name))
 
     // an udf to call `RowSet.toHiveString` on complex types(struct/array/map) and timestamp type.
+    // TODO: reuse the timeFormatters on greater scale if possible,
+    // recreating timeFormatters may cause performance issue, see [KYUUBI#5811]
     val toHiveStringUDF = udf[String, Row, String]((row, schemaDDL) => {
       val dt = DataType.fromDDL(schemaDDL)
       dt match {
         case StructType(Array(StructField(_, st: StructType, _, _))) =>
-          RowSet.toHiveString((row, st), nested = true)
+          RowSet.toHiveString(
+            (row, st),
+            nested = true,
+            timeFormatters = HiveResult.getTimeFormatters)
         case StructType(Array(StructField(_, at: ArrayType, _, _))) =>
-          RowSet.toHiveString((row.toSeq.head, at), nested = true)
+          RowSet.toHiveString(
+            (row.toSeq.head, at),
+            nested = true,
+            timeFormatters = HiveResult.getTimeFormatters)
         case StructType(Array(StructField(_, mt: MapType, _, _))) =>
-          RowSet.toHiveString((row.toSeq.head, mt), nested = true)
+          RowSet.toHiveString(
+            (row.toSeq.head, mt),
+            nested = true,
+            timeFormatters = HiveResult.getTimeFormatters)
         case StructType(Array(StructField(_, tt: TimestampType, _, _))) =>
-          RowSet.toHiveString((row.toSeq.head, tt), nested = true)
+          RowSet.toHiveString(
+            (row.toSeq.head, tt),
+            nested = true,
+            timeFormatters = HiveResult.getTimeFormatters)
         case _ =>
           throw new UnsupportedOperationException
       }
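The TODO in the hunk above marks the remaining per-row cost: the UDF body runs once per formatted value on the executors, so it still calls HiveResult.getTimeFormatters every time. One hypothetical follow-up direction, sketched below under the assumption that the session time zone is resolvable where the closure runs, is to reuse a single TimeFormatters instance per partition. The function name and mapPartitions placement are illustrative only, not part of the PR:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types.StructType

// Hypothetical per-partition reuse: one TimeFormatters instance per partition
// instead of one per formatted value.
def toHiveStrings(rdd: RDD[Row], schema: StructType): RDD[Seq[String]] =
  rdd.mapPartitions { rows =>
    val timeFormatters = HiveResult.getTimeFormatters // created once per partition
    rows.map { row =>
      schema.fields.indices.map { i =>
        HiveResult.toHiveString((row.get(i), schema(i).dataType), false, timeFormatters)
      }
    }
  }
```

Whether the session time zone (via SQLConf) is actually visible inside executor-side closures is exactly the kind of wrinkle that keeps this a TODO rather than part of this PR.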
File: RowSetSuite.scala
@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.hive.service.rpc.thrift.TProtocolVersion
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.HiveResult
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -165,14 +166,18 @@ class RowSetSuite extends KyuubiFunSuite {
     dateCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) =>
-        assert(b === RowSet.toHiveString(Date.valueOf(s"2018-11-${i + 1}") -> DateType))
+        assert(b === RowSet.toHiveString(
+          Date.valueOf(s"2018-11-${i + 1}") -> DateType,
+          timeFormatters = HiveResult.getTimeFormatters))
     }
 
     val tsCol = cols.next().getStringVal
     tsCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b ===
-        RowSet.toHiveString(Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType))
+        RowSet.toHiveString(
+          Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType,
+          timeFormatters = HiveResult.getTimeFormatters))
     }
 
     val binCol = cols.next().getBinaryVal
@@ -185,14 +190,16 @@
     arrCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b === RowSet.toHiveString(
-        Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType)))
+        Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType),
+        timeFormatters = HiveResult.getTimeFormatters))
     }
 
     val mapCol = cols.next().getStringVal
     mapCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b === RowSet.toHiveString(
-        Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType)))
+        Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType),
+        timeFormatters = HiveResult.getTimeFormatters))
     }
 
     val intervalCol = cols.next().getStringVal
@@ -241,7 +248,9 @@
     val r8 = iter.next().getColVals
     assert(r8.get(12).getStringVal.getValue === Array.fill(7)(7.7d).mkString("[", ",", "]"))
     assert(r8.get(13).getStringVal.getValue ===
-      RowSet.toHiveString(Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType)))
+      RowSet.toHiveString(
+        Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType),
+        timeFormatters = HiveResult.getTimeFormatters))
 
     val r9 = iter.next().getColVals
     assert(r9.get(14).getStringVal.getValue === new CalendarInterval(8, 8, 8).toString)