@@ -17,7 +17,12 @@
 
 package org.apache.spark.sql.avro
 
+import java.util
+import java.util.Collections
+
 import org.apache.avro.Schema
+import org.apache.avro.generic.{GenericData, GenericRecordBuilder}
+import org.apache.avro.message.{BinaryMessageDecoder, BinaryMessageEncoder}
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
@@ -127,6 +132,26 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
     }
   }
 
+  test("array of nested schema with seed") {
+    val seed = scala.util.Random.nextLong()
+    val rand = new scala.util.Random(seed)
+    val schema = StructType(
+      StructField("a",
+        ArrayType(
+          RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes),
+          containsNull = false),
+        nullable = false
+      ) :: Nil
+    )
+
+    withClue(s"Schema: $schema\nseed: $seed") {
+      val data = RandomDataGenerator.randomRow(rand, schema)
+      val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+      val input = Literal.create(converter(data), schema)
+      roundTripTest(input)
+    }
+  }
+
   test("read int as string") {
     val data = Literal(1)
     val avroTypeJson =
@@ -246,4 +271,46 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
     }.getMessage
     assert(message == "Cannot convert Catalyst type StringType to Avro type \"long\".")
   }
+
+  test("avro array can be generic java collection") {
+    val jsonFormatSchema =
+      """
+        |{ "type": "record",
+        |  "name": "record",
+        |  "fields" : [{
+        |    "name": "array",
+        |    "type": {
+        |      "type": "array",
+        |      "items": ["null", "int"]
+        |    }
+        |  }]
+        |}
+      """.stripMargin
+    val avroSchema = new Schema.Parser().parse(jsonFormatSchema)
+    val dataType = SchemaConverters.toSqlType(avroSchema).dataType
+    val deserializer = new AvroDeserializer(avroSchema, dataType)
+
+    def checkDeserialization(data: GenericData.Record, expected: Any): Unit = {
+      assert(checkResult(
+        expected,
+        deserializer.deserialize(data),
+        dataType, exprNullable = false
+      ))
+    }
+
+    def validateDeserialization(array: java.util.Collection[Integer]): Unit = {
+      val data = new GenericRecordBuilder(avroSchema)
+        .set("array", array)
+        .build()
+      val expected = InternalRow(new GenericArrayData(new util.ArrayList[Any](array)))
+      checkDeserialization(data, expected)
+
+      val reEncoded = new BinaryMessageDecoder[GenericData.Record](new GenericData(), avroSchema)
+        .decode(new BinaryMessageEncoder(new GenericData(), avroSchema).encode(data))
+      checkDeserialization(reEncoded, expected)
+    }
+
+    validateDeserialization(Collections.emptySet())
+    validateDeserialization(util.Arrays.asList(1, null, 3))
+  }
 }
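
Note: the second test pins down that an Avro ARRAY value may arrive as any java.util.Collection (a Set passed to GenericRecordBuilder, or a GenericData.Array after the binary round trip), not only as a java.util.List. A minimal sketch of the conversion this implies, using a hypothetical helper (collectionToArrayData is not part of this patch) and Spark's GenericArrayData:

import java.util.Collections
import org.apache.spark.sql.catalyst.util.GenericArrayData

// Hypothetical helper, not the patch's actual AvroDeserializer code:
// walk any java.util.Collection through its iterator instead of assuming
// java.util.List, then wrap the values in Catalyst's GenericArrayData.
def collectionToArrayData(c: java.util.Collection[Any]): GenericArrayData = {
  val values = new Array[Any](c.size())
  val it = c.iterator()
  var i = 0
  while (it.hasNext) {
    values(i) = it.next()
    i += 1
  }
  new GenericArrayData(values)
}

// An empty Set converts like an empty list, mirroring
// validateDeserialization(Collections.emptySet()) above.
assert(collectionToArrayData(Collections.emptySet[Any]()).numElements() == 0)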