Skip to content

Commit e645125

Browse files
steven-aerts and gengliangwang
authored and committed
[SPARK-30267][SQL] Avro arrays can be of any List
The Deserializer assumed that avro arrays are always of type `GenericData$Array` which is not the case. Assuming they are from java.util.List is safer and fixes a ClassCastException in some avro code. ### What changes were proposed in this pull request? Java.util.List has all the necessary methods and is the base class of GenericData$Array. ### Why are the changes needed? To prevent the following exception in more complex avro objects: ``` java.lang.ClassCastException: java.util.ArrayList cannot be cast to org.apache.avro.generic.GenericData$Array at org.apache.spark.sql.avro.AvroDeserializer.$anonfun$newWriter$19(AvroDeserializer.scala:170) at org.apache.spark.sql.avro.AvroDeserializer.$anonfun$newWriter$19$adapted(AvroDeserializer.scala:169) at org.apache.spark.sql.avro.AvroDeserializer.$anonfun$getRecordWriter$1(AvroDeserializer.scala:314) at org.apache.spark.sql.avro.AvroDeserializer.$anonfun$getRecordWriter$1$adapted(AvroDeserializer.scala:310) at org.apache.spark.sql.avro.AvroDeserializer.$anonfun$getRecordWriter$2(AvroDeserializer.scala:332) at org.apache.spark.sql.avro.AvroDeserializer.$anonfun$getRecordWriter$2$adapted(AvroDeserializer.scala:329) at org.apache.spark.sql.avro.AvroDeserializer.$anonfun$converter$3(AvroDeserializer.scala:56) at org.apache.spark.sql.avro.AvroDeserializer.deserialize(AvroDeserializer.scala:70) ``` ### Does this PR introduce any user-facing change? No ### How was this patch tested? The current tests already test this behavior. In essesence this patch just changes a type case to a more basic type. So I expect no functional impact. Closes #26907 from steven-aerts/spark-30267. Authored-by: Steven Aerts <steven.aerts@gmail.com> Signed-off-by: Gengliang Wang <gengliang.wang@databricks.com>
1 parent d32ed25 commit e645125

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,13 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
167167
case (ARRAY, ArrayType(elementType, containsNull)) =>
168168
val elementWriter = newWriter(avroType.getElementType, elementType, path)
169169
(updater, ordinal, value) =>
170-
val array = value.asInstanceOf[GenericData.Array[Any]]
170+
val array = value.asInstanceOf[java.util.Collection[Any]]
171171
val len = array.size()
172172
val result = createArrayData(elementType, len)
173173
val elementUpdater = new ArrayDataUpdater(result)
174174

175175
var i = 0
176-
while (i < len) {
177-
val element = array.get(i)
176+
for (element <- array.asScala) {
178177
if (element == null) {
179178
if (!containsNull) {
180179
throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " +

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717

1818
package org.apache.spark.sql.avro
1919

20+
import java.util
21+
import java.util.Collections
22+
2023
import org.apache.avro.Schema
24+
import org.apache.avro.generic.{GenericData, GenericRecordBuilder}
25+
import org.apache.avro.message.{BinaryMessageDecoder, BinaryMessageEncoder}
2126

2227
import org.apache.spark.{SparkException, SparkFunSuite}
2328
import org.apache.spark.sql.{RandomDataGenerator, Row}
@@ -127,6 +132,26 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
127132
}
128133
}
129134

135+
test("array of nested schema with seed") {
136+
val seed = scala.util.Random.nextLong()
137+
val rand = new scala.util.Random(seed)
138+
val schema = StructType(
139+
StructField("a",
140+
ArrayType(
141+
RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes),
142+
containsNull = false),
143+
nullable = false
144+
) :: Nil
145+
)
146+
147+
withClue(s"Schema: $schema\nseed: $seed") {
148+
val data = RandomDataGenerator.randomRow(rand, schema)
149+
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
150+
val input = Literal.create(converter(data), schema)
151+
roundTripTest(input)
152+
}
153+
}
154+
130155
test("read int as string") {
131156
val data = Literal(1)
132157
val avroTypeJson =
@@ -246,4 +271,46 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
246271
}.getMessage
247272
assert(message == "Cannot convert Catalyst type StringType to Avro type \"long\".")
248273
}
274+
275+
test("avro array can be generic java collection") {
276+
val jsonFormatSchema =
277+
"""
278+
|{ "type": "record",
279+
| "name": "record",
280+
| "fields" : [{
281+
| "name": "array",
282+
| "type": {
283+
| "type": "array",
284+
| "items": ["null", "int"]
285+
| }
286+
| }]
287+
|}
288+
""".stripMargin
289+
val avroSchema = new Schema.Parser().parse(jsonFormatSchema)
290+
val dataType = SchemaConverters.toSqlType(avroSchema).dataType
291+
val deserializer = new AvroDeserializer(avroSchema, dataType)
292+
293+
def checkDeserialization(data: GenericData.Record, expected: Any): Unit = {
294+
assert(checkResult(
295+
expected,
296+
deserializer.deserialize(data),
297+
dataType, exprNullable = false
298+
))
299+
}
300+
301+
def validateDeserialization(array: java.util.Collection[Integer]): Unit = {
302+
val data = new GenericRecordBuilder(avroSchema)
303+
.set("array", array)
304+
.build()
305+
val expected = InternalRow(new GenericArrayData(new util.ArrayList[Any](array)))
306+
checkDeserialization(data, expected)
307+
308+
val reEncoded = new BinaryMessageDecoder[GenericData.Record](new GenericData(), avroSchema)
309+
.decode(new BinaryMessageEncoder(new GenericData(), avroSchema).encode(data))
310+
checkDeserialization(reEncoded, expected)
311+
}
312+
313+
validateDeserialization(Collections.emptySet())
314+
validateDeserialization(util.Arrays.asList(1, null, 3))
315+
}
249316
}

0 commit comments

Comments
 (0)