Skip to content

Commit 2d4be52

Browse files
hvanhovell authored and cloud-fan committed
[SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders
### What changes were proposed in this pull request? This PR makes `RowEncoder` produce an `AgnosticEncoder`. The expression generation for these encoders is moved to `ScalaReflection` (this will be moved out in a subsequent PR). The generated serializer and deserializer expressions will slightly change for both schema and type based encoders. These are not semantically different from the old expressions. Concretely, the following changes have been introduced: - There is more type validation in maps/arrays/seqs for type based encoders. This should be a positive change, since it prevents users from passing wrong data through erasure hacks. - Array/Seq serialization is a bit more strict. In the old scenario it was possible to pass in sequences/arrays with the wrong type and/or nullability. ### Why are the changes needed? For the Spark Connect Scala Client we also want to be able to use `Row` based results. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? This is a refactoring; existing tests should be sufficient. Closes #39517 from hvanhovell/SPARK-41993. Authored-by: Herman van Hovell <herman@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 9b647e8 commit 2d4be52

File tree

12 files changed

+444
-499
lines changed

12 files changed

+444
-499
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,10 @@ object JavaTypeInference {
423423
case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)
424424

425425
case c if c == classOf[java.math.BigInteger] =>
426-
createSerializerForJavaBigInteger(inputObject)
426+
createSerializerForBigInteger(inputObject)
427427

428428
case c if c == classOf[java.math.BigDecimal] =>
429-
createSerializerForJavaBigDecimal(inputObject)
429+
createSerializerForBigDecimal(inputObject)
430430

431431
case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
432432
case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 214 additions & 101 deletions
Large diffs are not rendered by default.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,20 +158,29 @@ object SerializerBuildHelper {
158158
returnNullable = false)
159159
}
160160

161-
def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
161+
def createSerializerForBigDecimal(inputObject: Expression): Expression = {
162+
createSerializerForBigDecimal(inputObject, DecimalType.SYSTEM_DEFAULT)
163+
}
164+
165+
def createSerializerForBigDecimal(inputObject: Expression, dt: DecimalType): Expression = {
162166
CheckOverflow(StaticInvoke(
163167
Decimal.getClass,
164-
DecimalType.SYSTEM_DEFAULT,
168+
dt,
165169
"apply",
166170
inputObject :: Nil,
167-
returnNullable = false), DecimalType.SYSTEM_DEFAULT, nullOnOverflow)
171+
returnNullable = false), dt, nullOnOverflow)
168172
}
169173

170-
def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
171-
createSerializerForJavaBigDecimal(inputObject)
174+
def createSerializerForAnyDecimal(inputObject: Expression, dt: DecimalType): Expression = {
175+
CheckOverflow(StaticInvoke(
176+
Decimal.getClass,
177+
dt,
178+
"fromDecimal",
179+
inputObject :: Nil,
180+
returnNullable = false), dt, nullOnOverflow)
172181
}
173182

174-
def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
183+
def createSerializerForBigInteger(inputObject: Expression): Expression = {
175184
CheckOverflow(StaticInvoke(
176185
Decimal.getClass,
177186
DecimalType.BigIntDecimal,
@@ -180,10 +189,6 @@ object SerializerBuildHelper {
180189
returnNullable = false), DecimalType.BigIntDecimal, nullOnOverflow)
181190
}
182191

183-
def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
184-
createSerializerForJavaBigInteger(inputObject)
185-
}
186-
187192
def createSerializerForPrimitiveArray(
188193
inputObject: Expression,
189194
dataType: DataType): Expression = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala

Lines changed: 94 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,33 @@
1616
*/
1717
package org.apache.spark.sql.catalyst.encoders
1818

19+
import java.{sql => jsql}
1920
import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt}
2021
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
2122

2223
import scala.reflect.{classTag, ClassTag}
2324

24-
import org.apache.spark.sql.Encoder
25+
import org.apache.spark.sql.{Encoder, Row}
2526
import org.apache.spark.sql.types._
2627
import org.apache.spark.unsafe.types.CalendarInterval
2728

2829
/**
2930
* A non implementation specific encoder. This encoder contains all the information needed
3031
* to generate an implementation specific encoder (e.g. InternalRow <=> Custom Object).
32+
*
33+
* The input of the serialization does not need to match the external type of the encoder. This is
34+
* called lenient serialization. An example of this is lenient date serialization, in this case both
35+
* [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never lenient; it
36+
* will always produce an instance of the external type.
3137
*/
3238
trait AgnosticEncoder[T] extends Encoder[T] {
3339
def isPrimitive: Boolean
3440
def nullable: Boolean = !isPrimitive
3541
def dataType: DataType
3642
override def schema: StructType = StructType(StructField("value", dataType, nullable) :: Nil)
43+
def lenientSerialization: Boolean = false
3744
}
3845

39-
// TODO check RowEncoder
40-
// TODO check BeanEncoder
4146
object AgnosticEncoders {
4247
case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E])
4348
extends AgnosticEncoder[Option[E]] {
@@ -46,35 +51,48 @@ object AgnosticEncoders {
4651
override val clsTag: ClassTag[Option[E]] = ClassTag(classOf[Option[E]])
4752
}
4853

49-
case class ArrayEncoder[E](element: AgnosticEncoder[E])
54+
case class ArrayEncoder[E](element: AgnosticEncoder[E], containsNull: Boolean)
5055
extends AgnosticEncoder[Array[E]] {
5156
override def isPrimitive: Boolean = false
52-
override def dataType: DataType = ArrayType(element.dataType, element.nullable)
57+
override def dataType: DataType = ArrayType(element.dataType, containsNull)
5358
override val clsTag: ClassTag[Array[E]] = element.clsTag.wrap
5459
}
5560

56-
case class IterableEncoder[C <: Iterable[E], E](
61+
/**
62+
* Encoder for collections.
63+
*
64+
* This encoder can be lenient for [[Row]] encoders. In that case we allow [[Seq]], primitive
65+
* array (if any), and generic arrays as input.
66+
*/
67+
case class IterableEncoder[C, E](
5768
override val clsTag: ClassTag[C],
58-
element: AgnosticEncoder[E])
69+
element: AgnosticEncoder[E],
70+
containsNull: Boolean,
71+
override val lenientSerialization: Boolean)
5972
extends AgnosticEncoder[C] {
6073
override def isPrimitive: Boolean = false
61-
override val dataType: DataType = ArrayType(element.dataType, element.nullable)
74+
override val dataType: DataType = ArrayType(element.dataType, containsNull)
6275
}
6376

6477
case class MapEncoder[C, K, V](
6578
override val clsTag: ClassTag[C],
6679
keyEncoder: AgnosticEncoder[K],
67-
valueEncoder: AgnosticEncoder[V])
80+
valueEncoder: AgnosticEncoder[V],
81+
valueContainsNull: Boolean)
6882
extends AgnosticEncoder[C] {
6983
override def isPrimitive: Boolean = false
7084
override val dataType: DataType = MapType(
7185
keyEncoder.dataType,
7286
valueEncoder.dataType,
73-
valueEncoder.nullable)
87+
valueContainsNull)
7488
}
7589

76-
case class EncoderField(name: String, enc: AgnosticEncoder[_]) {
77-
def structField: StructField = StructField(name, enc.dataType, enc.nullable)
90+
case class EncoderField(
91+
name: String,
92+
enc: AgnosticEncoder[_],
93+
nullable: Boolean,
94+
metadata: Metadata) {
95+
def structField: StructField = StructField(name, enc.dataType, nullable, metadata)
7896
}
7997

8098
// This supports both Product and DefinedByConstructorParams
@@ -87,6 +105,13 @@ object AgnosticEncoders {
87105
override def dataType: DataType = schema
88106
}
89107

108+
case class RowEncoder(fields: Seq[EncoderField]) extends AgnosticEncoder[Row] {
109+
override def isPrimitive: Boolean = false
110+
override val schema: StructType = StructType(fields.map(_.structField))
111+
override def dataType: DataType = schema
112+
override def clsTag: ClassTag[Row] = classTag[Row]
113+
}
114+
90115
// This will only work for encoding from/to Spark's InternalRow format.
91116
// It is here for compatibility.
92117
case class UDTEncoder[E >: Null](
@@ -116,39 +141,74 @@ object AgnosticEncoders {
116141
}
117142

118143
// Primitive encoders
119-
case object PrimitiveBooleanEncoder extends LeafEncoder[Boolean](BooleanType)
120-
case object PrimitiveByteEncoder extends LeafEncoder[Byte](ByteType)
121-
case object PrimitiveShortEncoder extends LeafEncoder[Short](ShortType)
122-
case object PrimitiveIntEncoder extends LeafEncoder[Int](IntegerType)
123-
case object PrimitiveLongEncoder extends LeafEncoder[Long](LongType)
124-
case object PrimitiveFloatEncoder extends LeafEncoder[Float](FloatType)
125-
case object PrimitiveDoubleEncoder extends LeafEncoder[Double](DoubleType)
144+
abstract class PrimitiveLeafEncoder[E : ClassTag](dataType: DataType)
145+
extends LeafEncoder[E](dataType)
146+
case object PrimitiveBooleanEncoder extends PrimitiveLeafEncoder[Boolean](BooleanType)
147+
case object PrimitiveByteEncoder extends PrimitiveLeafEncoder[Byte](ByteType)
148+
case object PrimitiveShortEncoder extends PrimitiveLeafEncoder[Short](ShortType)
149+
case object PrimitiveIntEncoder extends PrimitiveLeafEncoder[Int](IntegerType)
150+
case object PrimitiveLongEncoder extends PrimitiveLeafEncoder[Long](LongType)
151+
case object PrimitiveFloatEncoder extends PrimitiveLeafEncoder[Float](FloatType)
152+
case object PrimitiveDoubleEncoder extends PrimitiveLeafEncoder[Double](DoubleType)
126153

127154
// Primitive wrapper encoders.
128-
case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
129-
case object BoxedBooleanEncoder extends LeafEncoder[java.lang.Boolean](BooleanType)
130-
case object BoxedByteEncoder extends LeafEncoder[java.lang.Byte](ByteType)
131-
case object BoxedShortEncoder extends LeafEncoder[java.lang.Short](ShortType)
132-
case object BoxedIntEncoder extends LeafEncoder[java.lang.Integer](IntegerType)
133-
case object BoxedLongEncoder extends LeafEncoder[java.lang.Long](LongType)
134-
case object BoxedFloatEncoder extends LeafEncoder[java.lang.Float](FloatType)
135-
case object BoxedDoubleEncoder extends LeafEncoder[java.lang.Double](DoubleType)
155+
abstract class BoxedLeafEncoder[E : ClassTag, P](
156+
dataType: DataType,
157+
val primitive: PrimitiveLeafEncoder[P])
158+
extends LeafEncoder[E](dataType)
159+
case object BoxedBooleanEncoder
160+
extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder)
161+
case object BoxedByteEncoder
162+
extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder)
163+
case object BoxedShortEncoder
164+
extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder)
165+
case object BoxedIntEncoder
166+
extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder)
167+
case object BoxedLongEncoder
168+
extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder)
169+
case object BoxedFloatEncoder
170+
extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder)
171+
case object BoxedDoubleEncoder
172+
extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder)
136173

137174
// Nullable leaf encoders
175+
case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
138176
case object StringEncoder extends LeafEncoder[String](StringType)
139177
case object BinaryEncoder extends LeafEncoder[Array[Byte]](BinaryType)
140-
case object SparkDecimalEncoder extends LeafEncoder[Decimal](DecimalType.SYSTEM_DEFAULT)
141-
case object ScalaDecimalEncoder extends LeafEncoder[BigDecimal](DecimalType.SYSTEM_DEFAULT)
142-
case object JavaDecimalEncoder extends LeafEncoder[JBigDecimal](DecimalType.SYSTEM_DEFAULT)
143178
case object ScalaBigIntEncoder extends LeafEncoder[BigInt](DecimalType.BigIntDecimal)
144179
case object JavaBigIntEncoder extends LeafEncoder[JBigInt](DecimalType.BigIntDecimal)
145180
case object CalendarIntervalEncoder extends LeafEncoder[CalendarInterval](CalendarIntervalType)
146181
case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
147182
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
148-
case object DateEncoder extends LeafEncoder[java.sql.Date](DateType)
149-
case object LocalDateEncoder extends LeafEncoder[LocalDate](DateType)
150-
case object TimestampEncoder extends LeafEncoder[java.sql.Timestamp](TimestampType)
151-
case object InstantEncoder extends LeafEncoder[Instant](TimestampType)
183+
case class DateEncoder(override val lenientSerialization: Boolean)
184+
extends LeafEncoder[jsql.Date](DateType)
185+
case class LocalDateEncoder(override val lenientSerialization: Boolean)
186+
extends LeafEncoder[LocalDate](DateType)
187+
case class TimestampEncoder(override val lenientSerialization: Boolean)
188+
extends LeafEncoder[jsql.Timestamp](TimestampType)
189+
case class InstantEncoder(override val lenientSerialization: Boolean)
190+
extends LeafEncoder[Instant](TimestampType)
152191
case object LocalDateTimeEncoder extends LeafEncoder[LocalDateTime](TimestampNTZType)
192+
193+
case class SparkDecimalEncoder(dt: DecimalType) extends LeafEncoder[Decimal](dt)
194+
case class ScalaDecimalEncoder(dt: DecimalType) extends LeafEncoder[BigDecimal](dt)
195+
case class JavaDecimalEncoder(dt: DecimalType, override val lenientSerialization: Boolean)
196+
extends LeafEncoder[JBigDecimal](dt)
197+
198+
val STRICT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = false)
199+
val STRICT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = false)
200+
val STRICT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = false)
201+
val STRICT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = false)
202+
val LENIENT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = true)
203+
val LENIENT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = true)
204+
val LENIENT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = true)
205+
val LENIENT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = true)
206+
207+
val DEFAULT_SPARK_DECIMAL_ENCODER: SparkDecimalEncoder =
208+
SparkDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
209+
val DEFAULT_SCALA_DECIMAL_ENCODER: ScalaDecimalEncoder =
210+
ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
211+
val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
212+
JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
153213
}
154214

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ import org.apache.spark.util.Utils
4747
object ExpressionEncoder {
4848

4949
def apply[T : TypeTag](): ExpressionEncoder[T] = {
50-
val enc = ScalaReflection.encoderFor[T]
50+
apply(ScalaReflection.encoderFor[T])
51+
}
52+
53+
def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = {
5154
new ExpressionEncoder[T](
5255
ScalaReflection.serializerFor(enc),
5356
ScalaReflection.deserializerFor(enc),

0 commit comments

Comments
 (0)