
Commit e377463

richardc-db authored and ilicmarkodb committed
[SPARK-48833][SQL][VARIANT] Support variant in InMemoryTableScan
### What changes were proposed in this pull request? adds support for variant type in `InMemoryTableScan`, or `df.cache()` by supporting writing variant values to an inmemory buffer. ### Why are the changes needed? prior to this PR, calling `df.cache()` on a df that has a variant would fail because `InMemoryTableScan` does not support reading variant types. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? added UTs ### Was this patch authored or co-authored using generative AI tooling? no Closes apache#47252 from richardc-db/variant_dfcache_support. Authored-by: Richard Chen <r.chen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 2e98dfa commit e377463
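
As a usage sketch (not part of the commit; it assumes an existing `SparkSession` named `spark`), this is the kind of query that previously failed when cached and now round-trips through the in-memory columnar cache:

```scala
// Hedged sketch: caching a DataFrame with a variant column, which this commit enables.
val df = spark.sql(
  """select parse_json(format_string('{"a": %s}', id)) as v
    |from range(0, 10)""".stripMargin)

df.cache()   // plan now goes through InMemoryRelation / InMemoryTableScan
df.count()   // forces the cached buffers to be built (previously failed here)
df.show()    // reads the variant values back out of the in-memory buffers
```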

File tree: 6 files changed, +139 -4 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala

Lines changed: 5 additions & 1 deletion
@@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
 
 /**
  * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -111,6 +111,10 @@ private[columnar] class IntervalColumnAccessor(buffer: ByteBuffer)
   extends BasicColumnAccessor[CalendarInterval](buffer, CALENDAR_INTERVAL)
   with NullableColumnAccessor
 
+private[columnar] class VariantColumnAccessor(buffer: ByteBuffer)
+  extends BasicColumnAccessor[VariantVal](buffer, VARIANT)
+  with NullableColumnAccessor
+
 private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
   extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala

Lines changed: 4 additions & 0 deletions
@@ -131,6 +131,9 @@ class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BI
 private[columnar]
 class IntervalColumnBuilder extends ComplexColumnBuilder(new IntervalColumnStats, CALENDAR_INTERVAL)
 
+private[columnar]
+class VariantColumnBuilder extends ComplexColumnBuilder(new VariantColumnStats, VARIANT)
+
 private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType)
   extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))
@@ -189,6 +192,7 @@ private[columnar] object ColumnBuilder {
     case s: StringType => new StringColumnBuilder(s)
     case BinaryType => new BinaryColumnBuilder
     case CalendarIntervalType => new IntervalColumnBuilder
+    case VariantType => new VariantColumnBuilder
     case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
       new CompactDecimalColumnBuilder(dt)
     case dt: DecimalType => new DecimalColumnBuilder(dt)

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala

Lines changed: 15 additions & 0 deletions
@@ -297,6 +297,21 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
     Array[Any](null, null, nullCount, count, sizeInBytes)
 }
 
+private[columnar] final class VariantColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    if (!row.isNullAt(ordinal)) {
+      val size = VARIANT.actualSize(row, ordinal)
+      sizeInBytes += size
+      count += 1
+    } else {
+      gatherNullStats()
+    }
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](null, null, nullCount, count, sizeInBytes)
+}
+
 private[columnar] final class IntervalColumnStats extends ColumnStats {
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala

Lines changed: 41 additions & 2 deletions
@@ -24,11 +24,11 @@ import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalBinaryType, PhysicalBooleanType, PhysicalByteType, PhysicalCalendarIntervalType, PhysicalDataType, PhysicalDecimalType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalMapType, PhysicalNullType, PhysicalShortType, PhysicalStringType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
 
 
 /**
@@ -815,6 +815,45 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
   }
 }
 
+/**
+ * Used to append/extract Java VariantVals into/from the underlying [[ByteBuffer]] of a column.
+ *
+ * Variants are encoded in `append` as:
+ * | value size | metadata size | value binary | metadata binary |
+ * and are only expected to be decoded in `extract`.
+ */
+private[columnar] object VARIANT
+  extends ColumnType[VariantVal] with DirectCopyColumnType[VariantVal] {
+  override def dataType: PhysicalDataType = PhysicalVariantType
+
+  /** Chosen to match the default size set in `VariantType`. */
+  override def defaultSize: Int = 2048
+
+  override def getField(row: InternalRow, ordinal: Int): VariantVal = row.getVariant(ordinal)
+
+  override def setField(row: InternalRow, ordinal: Int, value: VariantVal): Unit =
+    row.update(ordinal, value)
+
+  override def append(v: VariantVal, buffer: ByteBuffer): Unit = {
+    val valueSize = v.getValue().length
+    val metadataSize = v.getMetadata().length
+    ByteBufferHelper.putInt(buffer, valueSize)
+    ByteBufferHelper.putInt(buffer, metadataSize)
+    ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getValue()), buffer, valueSize)
+    ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getMetadata()), buffer, metadataSize)
+  }
+
+  override def extract(buffer: ByteBuffer): VariantVal = {
+    val valueSize = ByteBufferHelper.getInt(buffer)
+    val metadataSize = ByteBufferHelper.getInt(buffer)
+    val valueBuffer = ByteBuffer.allocate(valueSize)
+    ByteBufferHelper.copyMemory(buffer, valueBuffer, valueSize)
+    val metadataBuffer = ByteBuffer.allocate(metadataSize)
+    ByteBufferHelper.copyMemory(buffer, metadataBuffer, metadataSize)
+    new VariantVal(valueBuffer.array(), metadataBuffer.array())
+  }
+}
+
 private[columnar] object ColumnType {
   @tailrec
   def apply(dataType: DataType): ColumnType[_] = {
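
To make the `| value size | metadata size | value binary | metadata binary |` layout above concrete, here is a minimal, self-contained sketch of the same round trip using a plain `java.nio.ByteBuffer` and raw byte arrays; it stands in for Spark's `VariantVal` and `ByteBufferHelper` and is illustrative only, not code from the PR:

```scala
import java.nio.{ByteBuffer, ByteOrder}

object VariantBufferLayoutSketch {
  // append: | value size | metadata size | value binary | metadata binary |
  def append(value: Array[Byte], metadata: Array[Byte], buffer: ByteBuffer): Unit = {
    buffer.putInt(value.length)
    buffer.putInt(metadata.length)
    buffer.put(value)
    buffer.put(metadata)
  }

  // extract: read the two sizes first, then copy out the two binaries.
  def extract(buffer: ByteBuffer): (Array[Byte], Array[Byte]) = {
    val valueSize = buffer.getInt()
    val metadataSize = buffer.getInt()
    val value = new Array[Byte](valueSize)
    buffer.get(value)
    val metadata = new Array[Byte](metadataSize)
    buffer.get(metadata)
    (value, metadata)
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder())
    append(Array[Byte](1, 2, 3), Array[Byte](9, 8), buf)
    buf.flip() // switch from writing to reading
    val (v, m) = extract(buf)
    assert(v.sameElements(Array[Byte](1, 2, 3)) && m.sameElements(Array[Byte](9, 8)))
  }
}
```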

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala

Lines changed: 2 additions & 1 deletion
@@ -89,6 +89,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       case _: StringType => classOf[StringColumnAccessor].getName
       case BinaryType => classOf[BinaryColumnAccessor].getName
       case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
+      case VariantType => classOf[VariantColumnAccessor].getName
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
         classOf[CompactDecimalColumnAccessor].getName
       case dt: DecimalType => classOf[DecimalColumnAccessor].getName
@@ -101,7 +102,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       val createCode = dt match {
         case t if CodeGenerator.isPrimitiveType(dt) =>
           s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
-        case NullType | BinaryType | CalendarIntervalType =>
+        case NullType | BinaryType | CalendarIntervalType | VariantType =>
           s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
         case other =>
           s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),

sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala

Lines changed: 72 additions & 0 deletions
@@ -639,6 +639,78 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     }
   }
 
+  test("variant in a cached row-based df") {
+    val query = """select
+      parse_json(format_string('{\"a\": %s}', id)) v,
+      cast(null as variant) as null_v,
+      case when id % 2 = 0 then parse_json(cast(id as string)) else null end as some_null
+      from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("struct of variant in a cached row-based df") {
+    val query = """select named_struct(
+      'v', parse_json(format_string('{\"a\": %s}', id)),
+      'null_v', cast(null as variant),
+      'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+    ) v
+      from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("array of variant in a cached row-based df") {
+    val query = """select array(
+      parse_json(cast(id as string)),
+      parse_json(format_string('{\"a\": %s}', id)),
+      null,
+      case when id % 2 = 0 then parse_json(cast(id as string)) else null end) v
+      from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("map of variant in a cached row-based df") {
+    val query = """select map(
+      'v', parse_json(format_string('{\"a\": %s}', id)),
+      'null_v', cast(null as variant),
+      'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+    ) v
+      from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("variant in a cached column-based df") {
+    withTable("t") {
+      val query = """select named_struct(
+        'v', parse_json(format_string('{\"a\": %s}', id)),
+        'null_v', cast(null as variant),
+        'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+      ) v
+        from range(0, 10)"""
+      spark.sql(query).write.format("parquet").mode("overwrite").saveAsTable("t")
+      val df = spark.sql("select * from t")
+      df.cache()
+
+      val expected = spark.sql(query)
+      checkAnswer(df, expected.collect())
+    }
+  }
+
   test("variant_get size") {
     val largeKey = "x" * 1000
     val df = Seq(s"""{ "$largeKey": {"a" : 1 },
