
Commit e99cc51

Fixing nested WriteSupport and adding tests

1 parent 1dc5ac9

File tree: 4 files changed (+99 −38 lines)


sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala

Lines changed: 34 additions & 27 deletions
@@ -141,60 +141,67 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
   }

   private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
-    schema match {
-      case t @ ArrayType(_) => writeArray(t, value.asInstanceOf[Row])
-      case t @ MapType(_, _) => writeMap(t, value.asInstanceOf[Map[Any, Any]])
-      case t @ StructType(_) => writeStruct(t, value.asInstanceOf[Row])
-      case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
+    if (value != null && value != Nil) {
+      schema match {
+        case t @ ArrayType(_) => writeArray(t, value.asInstanceOf[Row])
+        case t @ MapType(_, _) => writeMap(t, value.asInstanceOf[Map[Any, Any]])
+        case t @ StructType(_) => writeStruct(t, value.asInstanceOf[Row])
+        case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
+      }
     }
   }

   private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = {
-    schema match {
-      case StringType => writer.addBinary(
-        Binary.fromByteArray(
-          value.asInstanceOf[String].getBytes("utf-8")
+    if (value != null && value != Nil) {
+      schema match {
+        case StringType => writer.addBinary(
+          Binary.fromByteArray(
+            value.asInstanceOf[String].getBytes("utf-8")
+          )
         )
-      )
-      case IntegerType => writer.addInteger(value.asInstanceOf[Int])
-      case LongType => writer.addLong(value.asInstanceOf[Long])
-      case DoubleType => writer.addDouble(value.asInstanceOf[Double])
-      case FloatType => writer.addFloat(value.asInstanceOf[Float])
-      case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
-      case _ => sys.error(s"Do not know how to writer $schema to consumer")
+        case IntegerType => writer.addInteger(value.asInstanceOf[Int])
+        case LongType => writer.addLong(value.asInstanceOf[Long])
+        case DoubleType => writer.addDouble(value.asInstanceOf[Double])
+        case FloatType => writer.addFloat(value.asInstanceOf[Float])
+        case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
+        case _ => sys.error(s"Do not know how to writer $schema to consumer")
+      }
     }
   }

   private[parquet] def writeStruct(schema: StructType, struct: Row): Unit = {
-    val fields = schema.fields.toArray
-    writer.startGroup()
-    var i = 0
-    while(i < fields.size) {
-      writer.startField(fields(i).name, i)
-      writeValue(fields(i).dataType, struct(i))
-      writer.endField(fields(i).name, i)
-      i = i + 1
+    if (struct != null && struct != Nil) {
+      val fields = schema.fields.toArray
+      writer.startGroup()
+      var i = 0
+      while(i < fields.size) {
+        if (struct(i) != null && struct(i) != Nil) {
+          writer.startField(fields(i).name, i)
+          writeValue(fields(i).dataType, struct(i))
+          writer.endField(fields(i).name, i)
+        }
+        i = i + 1
+      }
+      writer.endGroup()
     }
-    writer.endGroup()
   }

   private[parquet] def writeArray(schema: ArrayType, array: Row): Unit = {
     val elementType = schema.elementType
     writer.startGroup()
     if (array.size > 0) {
       writer.startField("values", 0)
-      writer.startGroup()
       var i = 0
       while(i < array.size) {
         writeValue(elementType, array(i))
         i = i + 1
       }
-      writer.endGroup()
       writer.endField("values", 0)
     }
     writer.endGroup()
   }

+  // TODO: this does not allow null values! Should these be supported?
   private[parquet] def writeMap(schema: MapType, map: Map[_, _]): Unit = {
     writer.startGroup()
     if (map.size > 0) {
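The pattern behind the new null guards deserves a note: in Parquet's record-shredding model, a null optional field is encoded by omitting its startField/endField calls entirely, not by writing a placeholder value. A minimal sketch of that convention, using a simplified stand-in for parquet's RecordConsumer (the trait and method subset below are illustrative, not the real API surface):

  // Simplified stand-in for parquet's RecordConsumer (illustrative only).
  trait SimpleConsumer {
    def startField(name: String, index: Int): Unit
    def endField(name: String, index: Int): Unit
    def addInteger(value: Int): Unit
  }

  // A null optional field is written by emitting nothing at all; writing a
  // sentinel value instead would silently corrupt the record.
  def writeOptionalInt(c: SimpleConsumer, name: String, index: Int, value: Any): Unit =
    if (value != null) {
      c.startField(name, index)
      c.addInteger(value.asInstanceOf[Int])
      c.endField(name, index)
    }

This is the same shape as the guard writeStruct now places around each field: skip the startField/endField pair whenever the value is null.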

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala

Lines changed: 2 additions & 2 deletions
@@ -182,13 +182,13 @@ private[sql] object ParquetTestData {
       |optional group data1 {
       |repeated group map {
       |required binary key;
-      |optional int32 value;
+      |required int32 value;
       |}
       |}
       |required group data2 {
       |repeated group map {
       |required binary key;
-      |optional group value {
+      |required group value {
       |required int64 payload1;
       |optional binary payload2;
       |}
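These fixtures track the converter change below: map values are now assumed non-nullable, so the test schemas declare them required. A hedged sketch of the corresponding check, parsing a schema shaped like the fixed data1 group with parquet's MessageTypeParser (package names as in this era of parquet-mr; the message name m is arbitrary):

  import parquet.schema.MessageTypeParser
  import parquet.schema.Type.Repetition

  // Parse a schema matching the corrected data1 fixture.
  val schema = MessageTypeParser.parseMessageType(
    """message m {
      |  optional group data1 {
      |    repeated group map {
      |      required binary key;
      |      required int32 value;
      |    }
      |  }
      |}""".stripMargin)

  // Mirror of the converter's new assertions: both fields of the repeated
  // key/value group must have repetition REQUIRED.
  val keyValueGroup = schema.getType("data1").asGroupType().getType("map").asGroupType()
  assert(keyValueGroup.getType(0).getRepetition == Repetition.REQUIRED)
  assert(keyValueGroup.getType(1).getRepetition == Repetition.REQUIRED)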

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala

Lines changed: 14 additions & 9 deletions
@@ -64,13 +64,14 @@ private[parquet] object ParquetTypesConverter {
    * <ul>
    * <li> Primitive types are converter to the corresponding primitive type.</li>
    * <li> Group types that have a single field that is itself a group, which has repetition
-   *      level `REPEATED` are treated as follows:<ul>
-   *      <li> If the nested group has name `values` and repetition level `REPEATED`, the
-   *           surrounding group is converted into an [[ArrayType]] with the
-   *           corresponding field type (primitive or complex) as element type.</li>
-   *      <li> If the nested group has name `map`, repetition level `REPEATED` and two fields
-   *           (named `key` and `value`), the surrounding group is converted into a [[MapType]]
-   *           with the corresponding key and value (value possibly complex) types.</li>
+   *      level `REPEATED`, are treated as follows:<ul>
+   *      <li> If the nested group has name `values`, the surrounding group is converted
+   *           into an [[ArrayType]] with the corresponding field type (primitive or
+   *           complex) as element type.</li>
+   *      <li> If the nested group has name `map` and two fields (named `key` and `value`),
+   *           the surrounding group is converted into a [[MapType]]
+   *           with the corresponding key and value (value possibly complex) types.
+   *           Note that we currently assume map values are not nullable.</li>
    * <li> Other group types are converted into a [[StructType]] with the corresponding
    *      field types.</li></ul></li>
    * </ul>
@@ -121,15 +122,19 @@ private[parquet] object ParquetTypesConverter {
         keyValueGroup.getFieldCount == 2,
         "Parquet Map type malformatted: nested group should have 2 (key, value) fields!")
       val keyType = toDataType(keyValueGroup.getFields.apply(0))
+      assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
       val valueType = toDataType(keyValueGroup.getFields.apply(1))
+      assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
       new MapType(keyType, valueType)
     }
     case _ => {
       // Note: the order of these checks is important!
       if (correspondsToMap(groupType)) { // MapType
         val keyValueGroup = groupType.getFields.apply(0).asGroupType()
         val keyType = toDataType(keyValueGroup.getFields.apply(0))
+        assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
         val valueType = toDataType(keyValueGroup.getFields.apply(1))
+        assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
         new MapType(keyType, valueType)
       } else if (correspondsToArray(groupType)) { // ArrayType
         val elementType = toDataType(groupType.getFields.apply(0))
@@ -240,13 +245,13 @@ private[parquet] object ParquetTypesConverter {
       fromDataType(
         keyType,
         CatalystConverter.MAP_KEY_SCHEMA_NAME,
-        false,
+        nullable = false,
         inArray = false)
     val parquetValueType =
       fromDataType(
         valueType,
         CatalystConverter.MAP_VALUE_SCHEMA_NAME,
-        true,
+        nullable = false,
         inArray = false)
     ConversionPatterns.mapType(
       repetition,
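One subtlety in the last hunk: switching to named arguments is not just cosmetic, since the map value's flag also changes from true to false to match the new REQUIRED assertions. A small sketch of why named arguments pay off at call sites like fromDataType (fieldSpec below is a hypothetical function, not the real signature):

  // Hypothetical signature shaped like the fromDataType call sites above.
  def fieldSpec(name: String, nullable: Boolean, inArray: Boolean): String =
    s"$name(nullable=$nullable, inArray=$inArray)"

  // Positional booleans are easy to transpose and hard to review:
  fieldSpec("value", false, false)
  // Named arguments document intent at the call site and let the compiler
  // reject a misnamed parameter:
  fieldSpec("value", nullable = false, inArray = false)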

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 49 additions & 0 deletions
@@ -553,6 +553,55 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     assert(result2(0)(1) === "the answer")
   }

+  test("Writing out Addressbook and reading it back in") {
+    implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
+    val tmpdir = Utils.createTempDir()
+    val result = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir1.toString)
+      .toSchemaRDD
+    result.saveAsParquetFile(tmpdir.toString)
+    TestSQLContext
+      .parquetFile(tmpdir.toString)
+      .toSchemaRDD
+      .registerAsTable("tmpcopy")
+    val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect()
+    assert(tmpdata.size === 2)
+    assert(tmpdata(0).size === 2)
+    assert(tmpdata(0)(0) === "Julien Le Dem")
+    assert(tmpdata(0)(1) === "Chris Aniszczyk")
+    assert(tmpdata(1)(0) === "A. Nonymous")
+    assert(tmpdata(1)(1) === null)
+    Utils.deleteRecursively(tmpdir)
+  }
+
+  test("Writing out Map and reading it back in") {
+    implicit def anyToMap(value: Any) = value.asInstanceOf[Map[String, Row]]
+    val data = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir4.toString)
+      .toSchemaRDD
+    val tmpdir = Utils.createTempDir()
+    data.saveAsParquetFile(tmpdir.toString)
+    TestSQLContext
+      .parquetFile(tmpdir.toString)
+      .toSchemaRDD
+      .registerAsTable("tmpmapcopy")
+    val result1 = sql("SELECT data2 FROM tmpmapcopy").collect()
+    assert(result1.size === 1)
+    val entry1 = result1(0)(0).getOrElse("7", null)
+    assert(entry1 != null)
+    assert(entry1(0) === 42)
+    assert(entry1(1) === "the answer")
+    val entry2 = result1(0)(0).getOrElse("8", null)
+    assert(entry2 != null)
+    assert(entry2(0) === 49)
+    assert(entry2(1) === null)
+    val result2 = sql("SELECT data2[7].payload1, data2[7].payload2 FROM tmpmapcopy").collect()
+    assert(result2.size === 1)
+    assert(result2(0)(0) === 42.toLong)
+    assert(result2(0)(1) === "the answer")
+    Utils.deleteRecursively(tmpdir)
+  }
+
   /**
    * Creates an empty SchemaRDD backed by a ParquetRelation.
    *
