
Commit b539fde

First commit for MapType
1 parent a594aed commit b539fde

4 files changed: +236 -23 lines changed

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala

Lines changed: 71 additions & 1 deletion
@@ -26,6 +26,7 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute}
 import org.apache.spark.sql.parquet.CatalystConverter.FieldType
+import scala.collection.mutable

 private[parquet] object CatalystConverter {
   // The type internally used for fields
@@ -55,6 +56,14 @@ private[parquet] object CatalystConverter {
     case StructType(fields: Seq[StructField]) => {
       new CatalystStructConverter(fields, fieldIndex, parent)
     }
+    case MapType(keyType: DataType, valueType: DataType) => {
+      new CatalystMapConverter(
+        Seq(
+          new FieldType("key", keyType, false),
+          new FieldType("value", valueType, true)),
+        fieldIndex,
+        parent)
+    }
     case ctype: NativeType => {
       // note: for some reason matching for StringType fails so use this ugly if instead
       if (ctype == StringType) {
@@ -396,6 +405,67 @@ private[parquet] class CatalystStructConverter(
   override def getCurrentRecord: Row = throw new UnsupportedOperationException
 }

-// TODO: add MapConverter
+private[parquet] class CatalystMapConverter(
+    protected[parquet] val schema: Seq[FieldType],
+    override protected[parquet] val index: Int,
+    override protected[parquet] val parent: CatalystConverter)
+  extends GroupConverter with CatalystConverter {
+
+  private val map = new mutable.HashMap[Any, Any]()
+
+  private val keyValueConverter = new GroupConverter with CatalystConverter {
+    private var currentKey: Any = null
+    private var currentValue: Any = null
+    val keyConverter = CatalystConverter.createConverter(schema(0), 0, this)
+    val valueConverter = CatalystConverter.createConverter(schema(1), 1, this)
+
+    override def getConverter(fieldIndex: Int): Converter = if (fieldIndex == 0) keyConverter else valueConverter
+
+    override def end(): Unit = CatalystMapConverter.this.map += currentKey -> currentValue
+
+    override def start(): Unit = {
+      currentKey = null
+      currentValue = null
+    }
+
+    override protected[parquet] val size: Int = 2
+    override protected[parquet] val index: Int = 0
+    override protected[parquet] val parent: CatalystConverter = CatalystMapConverter.this
+
+    override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = fieldIndex match {
+      case 0 =>
+        currentKey = value
+      case 1 =>
+        currentValue = value
+      case _ =>
+        throw new RuntimeException(s"trying to update Map with fieldIndex $fieldIndex")
+    }
+
+    override protected[parquet] def clearBuffer(): Unit = {}
+    override def getCurrentRecord: Row = throw new UnsupportedOperationException
+  }
+
+  override protected[parquet] val size: Int = 1
+
+  override protected[parquet] def clearBuffer(): Unit = {}
+
+  override def start(): Unit = {
+    map.clear()
+  }
+
+  // TODO: think about reusing the buffer
+  override def end(): Unit = {
+    assert(!isRootConverter)
+    parent.updateField(index, map)
+  }
+
+  override def getConverter(fieldIndex: Int): Converter = keyValueConverter
+
+  override def getCurrentRecord: Row = throw new UnsupportedOperationException
+
+  override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit =
+    throw new UnsupportedOperationException
+}
+
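The record-assembly flow behind CatalystMapConverter can be hard to see from the diff alone: Parquet calls start() on the outer converter once per map cell, then start()/updateField()/end() on the nested key-value converter once per repeated `map` group, and the outer end() hands the finished map to the parent. Below is a minimal standalone sketch of that callback sequence; the names (SketchMapConverter, startEntry, endEntry) are invented for illustration and are not part of this commit or of Parquet's API.

import scala.collection.mutable

class SketchMapConverter {
  private val map = new mutable.HashMap[Any, Any]()
  private var currentKey: Any = null
  private var currentValue: Any = null

  def start(): Unit = map.clear()                         // a new map cell begins
  def startEntry(): Unit = { currentKey = null; currentValue = null }
  def updateField(fieldIndex: Int, value: Any): Unit =
    if (fieldIndex == 0) currentKey = value else currentValue = value  // 0 = key, 1 = value
  def endEntry(): Unit = map += currentKey -> currentValue
  def end(): Map[Any, Any] = map.toMap                    // handed to the parent converter
}

object SketchMapConverterDemo extends App {
  val c = new SketchMapConverter
  c.start()
  c.startEntry(); c.updateField(0, "key1"); c.updateField(1, 1); c.endEntry()
  c.startEntry(); c.updateField(0, "key2"); c.updateField(1, 2); c.endEntry()
  println(c.end())  // Map(key1 -> 1, key2 -> 2)
}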

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala

Lines changed: 54 additions & 0 deletions
@@ -167,9 +167,32 @@ private[sql] object ParquetTestData {
     |}
   """.stripMargin

+  val testNestedSchema4 =
+    """
+      |message TestNested4 {
+        |required int32 x;
+        |optional group data1 {
+          |repeated group map {
+            |required binary key;
+            |optional int32 value;
+          |}
+        |}
+        |required group data2 {
+          |repeated group map {
+            |required int32 key;
+            |optional group value {
+              |required int64 payload1;
+              |optional binary payload2;
+            |}
+          |}
+        |}
+      |}
+    """.stripMargin
+
   val testNestedDir1 = Utils.createTempDir()
   val testNestedDir2 = Utils.createTempDir()
   val testNestedDir3 = Utils.createTempDir()
+  val testNestedDir4 = Utils.createTempDir()

   lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString)
   lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString)
@@ -327,6 +350,37 @@ private[sql] object ParquetTestData {
     writer.close()
   }

+  def writeNestedFile4() {
+    testNestedDir4.delete()
+    val path: Path = testNestedDir4
+    val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema4)
+
+    val r1 = new SimpleGroup(schema)
+    r1.add(0, 7)
+    val map1 = r1.addGroup(1)
+    val keyValue1 = map1.addGroup(0)
+    keyValue1.add(0, "key1")
+    keyValue1.add(1, 1)
+    val keyValue2 = map1.addGroup(0)
+    keyValue2.add(0, "key2")
+    keyValue2.add(1, 2)
+    val map2 = r1.addGroup(2)
+    val keyValue3 = map2.addGroup(0)
+    keyValue3.add(0, 7)
+    val valueGroup1 = keyValue3.addGroup(1)
+    valueGroup1.add(0, 42.toLong)
+    valueGroup1.add(1, "the answer")
+    val keyValue4 = map2.addGroup(0)
+    keyValue4.add(0, 8)
+    val valueGroup2 = keyValue4.addGroup(1)
+    valueGroup2.add(0, 49.toLong)
+
+    val writeSupport = new TestGroupWriteSupport(schema)
+    val writer = new ParquetWriter[Group](path, writeSupport)
+    writer.write(r1)
+    writer.close()
+  }
+
   def readNestedFile(path: File, schemaString: String): Unit = {
     val configuration = new Configuration()
     val fs: FileSystem = path.getFileSystem(configuration)
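For orientation, here is the single record that writeNestedFile4() above is meant to write, spelled out as plain Scala values; this is a hedged reading of the add(...)/addGroup(...) calls, not output captured from the test data.

object NestedFile4ExpectedRecord extends App {
  // x = 7, data1 is a string-to-int map, data2 maps ints to (payload1, payload2) structs;
  // payload2 is never written for key 8, so it surfaces as null.
  val x = 7
  val data1 = Map("key1" -> 1, "key2" -> 2)
  val data2 = Map(7 -> (42L, "the answer"), 8 -> (49L, null))
  println((x, data1, data2))
}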

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala

Lines changed: 65 additions & 21 deletions
@@ -63,13 +63,17 @@ private[parquet] object ParquetTypesConverter {
   * Note that we apply the following conversion rules:
   * <ul>
   * <li> Primitive types are converted to the corresponding primitive type.</li>
-  * <li> Group types that have a single field with repetition `REPEATED` or themselves
-  *      have repetition level `REPEATED` are converted to an [[ArrayType]] with the
-  *      corresponding field type (possibly primitive) as element type.</li>
+  * <li> Group types that have a single field that is itself a group, which has repetition
+  *      level `REPEATED` and two fields (named `key` and `value`), are converted to
+  *      a [[MapType]] with the corresponding key and value (value possibly complex)
+  *      as element type.</li>
   * <li> Other group types are converted as follows:<ul>
-  *      <li> If they have a single field, they are converted into a [[StructType]] with
+  *      <li> Group types that have a single field with repetition `REPEATED` or themselves
+  *           have repetition level `REPEATED` are converted to an [[ArrayType]] with the
+  *           corresponding field type (possibly primitive) as element type.</li>
+  *      <li> Other groups with a single field are converted into a [[StructType]] with
   *           the corresponding field type.</li>
-  *      <li> If they have more than one field and repetition level `REPEATED` they are
+  *      <li> If groups have more than one field and repetition level `REPEATED` they are
   *           converted into an [[ArrayType]] with the corresponding [[StructType]] as complex
   *           element type.</li>
   *      <li> Otherwise they are converted into a [[StructType]] with the corresponding
@@ -82,16 +86,33 @@ private[parquet] object ParquetTypesConverter {
   * @return The corresponding Catalyst type.
   */
  def toDataType(parquetType: ParquetType): DataType = {
+    def correspondsToMap(groupType: ParquetGroupType): Boolean = {
+      if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) {
+        false
+      } else {
+        // This mostly follows the convention in ``parquet.schema.ConversionPatterns``
+        val keyValueGroup = groupType.getFields.apply(0).asGroupType()
+        keyValueGroup.getRepetition == Repetition.REPEATED &&
+          keyValueGroup.getName == "map" &&
+          keyValueGroup.getFields.apply(0).getName == "key" &&
+          keyValueGroup.getFields.apply(1).getName == "value"
+      }
+    }
+    def correspondsToArray(groupType: ParquetGroupType): Boolean = {
+      groupType.getFieldCount == 1 &&
+        (groupType.getFields.apply(0).getRepetition == Repetition.REPEATED ||
+          groupType.getRepetition == Repetition.REPEATED)
+    }
+
    if (parquetType.isPrimitive) {
      toPrimitiveDataType(parquetType.asPrimitiveType.getPrimitiveTypeName)
-    }
-    else {
+    } else {
      val groupType = parquetType.asGroupType()
      parquetType.getOriginalType match {
        // if the schema was constructed programmatically there may be hints how to convert
        // it inside the metadata via the OriginalType field
        case ParquetOriginalType.LIST => { // TODO: check enums!
-          val fields = groupType.getFields.map {
+          val fields = groupType.getFields.map {
            field => new StructField(
              field.getName,
              toDataType(field),
@@ -103,16 +124,29 @@ private[parquet] object ParquetTypesConverter {
            new ArrayType(StructType(fields))
          }
        }
+        case ParquetOriginalType.MAP => {
+          assert(
+            !groupType.getFields.apply(0).isPrimitive,
+            "Parquet Map type malformatted: expected nested group for map!")
+          val keyValueGroup = groupType.getFields.apply(0).asGroupType()
+          assert(
+            keyValueGroup.getFieldCount == 2,
+            "Parquet Map type malformatted: nested group should have 2 (key, value) fields!")
+          val keyType = toDataType(keyValueGroup.getFields.apply(0))
+          val valueType = toDataType(keyValueGroup.getFields.apply(1))
+          new MapType(keyType, valueType)
+        }
        case _ => {
-          // everything else nested becomes a Struct, unless it has a single repeated field
-          // in which case it becomes an array (this should correspond to the inverse operation of
-          // parquet.schema.ConversionPatterns.listType)
-          if (groupType.getFieldCount == 1 &&
-            (groupType.getFields.apply(0).getRepetition == Repetition.REPEATED ||
-              groupType.getRepetition == Repetition.REPEATED)) {
+          // Note: the order of these checks is important!
+          if (correspondsToMap(groupType)) { // MapType
+            val keyValueGroup = groupType.getFields.apply(0).asGroupType()
+            val keyType = toDataType(keyValueGroup.getFields.apply(0))
+            val valueType = toDataType(keyValueGroup.getFields.apply(1))
+            new MapType(keyType, valueType)
+          } else if (correspondsToArray(groupType)) { // ArrayType
            val elementType = toDataType(groupType.getFields.apply(0))
            new ArrayType(elementType)
-          } else {
+          } else { // everything else: StructType
            val fields = groupType
              .getFields
              .map(ptype => new StructField(
@@ -164,7 +198,10 @@ private[parquet] object ParquetTypesConverter {
   * <ul>
   * <li> Primitive types are converted into Parquet's primitive types.</li>
   * <li> [[org.apache.spark.sql.catalyst.types.StructType]]s are converted
-  *      into Parquet's `GroupType` with the corresponding field types.</li>
+  *      into Parquet's `GroupType` with the corresponding field types.</li>
+  * <li> [[org.apache.spark.sql.catalyst.types.MapType]]s are converted
+  *      into a nested (2-level) Parquet `GroupType` with two fields: a key type and
+  *      a value type. The nested group has repetition level `REPEATED`.</li>
   * <li> [[org.apache.spark.sql.catalyst.types.ArrayType]]s are handled as follows:<ul>
   * <li> If their element is complex, that is of type
   *      [[org.apache.spark.sql.catalyst.types.StructType]], they are converted
@@ -174,18 +211,18 @@ private[parquet] object ParquetTypesConverter {
   *      that is also a list but has only a single field of the type corresponding to
   *      the element type.</li></ul></li>
   * </ul>
-  * Parquet's repetition level is set according to the following rule:
+  * Parquet's repetition level is generally set according to the following rule:
   * <ul>
-  * <li> If the call to `fromDataType` is recursive inside an enclosing `ArrayType`, then
-  *      the repetition level is set to `REPEATED`.</li>
+  * <li> If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or
+  *      `MapType`, then the repetition level is set to `REPEATED`.</li>
   * <li> Otherwise, if the attribute whose type is converted is `nullable`, the Parquet
   *      type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.</li>
   * </ul>
-  * The single expection to this rule is an [[org.apache.spark.sql.catalyst.types.ArrayType]]
+  * The single exception to this rule is an [[org.apache.spark.sql.catalyst.types.ArrayType]]
   * that contains a [[org.apache.spark.sql.catalyst.types.StructType]], whose repetition level
   * is always set to `REPEATED`.
   *
-  @param ctype The type to convert.
+  * @param ctype The type to convert.
   * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]]
   *             whose type is converted
   * @param nullable When true indicates that the attribute is nullable
@@ -239,6 +276,13 @@ private[parquet] object ParquetTypesConverter {
          }
          new ParquetGroupType(repetition, name, fields)
        }
+        case MapType(keyType, valueType) => {
+          ConversionPatterns.mapType(
+            repetition,
+            name,
+            fromDataType(keyType, "key", false, inArray = false),
+            fromDataType(valueType, "value", true, inArray = false))
+        }
        case _ => sys.error(s"Unsupported datatype $ctype")
      }
    }
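To make the new rules concrete, a hedged round-trip sketch follows: it converts a Catalyst MapType to a Parquet group with fromDataType and back with toDataType, as added in this file. The driver object, its placement in the org.apache.spark.sql.parquet package (needed because ParquetTypesConverter is private[parquet]), and the explicit nullable/inArray arguments are assumptions based on the calls visible in this patch, not a verified signature; the expected Parquet layout is the ConversionPatterns.mapType shape, a repeated inner group named `map` with `key`/`value` fields.

package org.apache.spark.sql.parquet

import org.apache.spark.sql.catalyst.types._

// Hypothetical driver, not part of the commit.
object MapTypeRoundTripSketch extends App {
  val catalystType = MapType(StringType, IntegerType)
  // Catalyst -> Parquet: an optional group "data1" containing a repeated "map" group
  // with a required binary "key" and an optional int32 "value".
  val parquetType = ParquetTypesConverter.fromDataType(catalystType, "data1", true, inArray = false)
  println(parquetType)
  // Parquet -> Catalyst: should recover MapType(StringType, IntegerType).
  println(ParquetTypesConverter.toDataType(parquetType))
}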

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 46 additions & 1 deletion
@@ -74,6 +74,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
    ParquetTestData.writeFilterFile()
    ParquetTestData.writeNestedFile1()
    ParquetTestData.writeNestedFile2()
+    ParquetTestData.writeNestedFile3()
+    ParquetTestData.writeNestedFile4()
    testRDD = parquetFile(ParquetTestData.testDir.toString)
    testRDD.registerAsTable("testsource")
    parquetFile(ParquetTestData.testFilterDir.toString)
@@ -85,6 +87,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
    Utils.deleteRecursively(ParquetTestData.testFilterDir)
    Utils.deleteRecursively(ParquetTestData.testNestedDir1)
    Utils.deleteRecursively(ParquetTestData.testNestedDir2)
+    Utils.deleteRecursively(ParquetTestData.testNestedDir3)
+    Utils.deleteRecursively(ParquetTestData.testNestedDir4)
    // here we should also unregister the table??
  }
@@ -495,7 +499,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA

  test("nested structs") {
    implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
-    ParquetTestData.writeNestedFile3()
    val data = TestSQLContext
      .parquetFile(ParquetTestData.testNestedDir3.toString)
      .toSchemaRDD
@@ -514,6 +517,48 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
    assert(result3(0)(0) === false)
  }

+  test("simple map") {
+    implicit def anyToMap(value: Any) = value.asInstanceOf[collection.mutable.HashMap[String, Int]]
+    val data = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir4.toString)
+      .toSchemaRDD
+    data.registerAsTable("mapTable")
+    val result1 = sql("SELECT data1 FROM mapTable").collect()
+    assert(result1.size === 1)
+    assert(result1(0)(0).toMap.getOrElse("key1", 0) === 1)
+    assert(result1(0)(0).toMap.getOrElse("key2", 0) === 2)
+  }
+
+  test("map with struct values") {
+    //implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
+    implicit def anyToMap(value: Any) = value.asInstanceOf[collection.mutable.HashMap[Int, Row]]
+    //val data = TestSQLContext
+    //  .parquetFile(ParquetTestData.testNestedDir4.toString)
+    //  .toSchemaRDD
+    val data = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir4.toString)
+      .toSchemaRDD
+    data.registerAsTable("mapTable")
+
+    /*ParquetTestData.readNestedFile(
+      ParquetTestData.testNestedDir4,
+      ParquetTestData.testNestedSchema4)
+    val result = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir4.toString)
+      .toSchemaRDD
+      .collect()*/
+    val result1 = sql("SELECT data2 FROM mapTable").collect()
+    assert(result1.size === 1)
+    val entry1 = result1(0)(0).getOrElse(7, null)
+    assert(entry1 != null)
+    assert(entry1(0) === 42)
+    assert(entry1(1) === "the answer")
+    val entry2 = result1(0)(0).getOrElse(8, null)
+    assert(entry2 != null)
+    assert(entry2(0) === 49)
+    assert(entry2(1) === null)
+  }
+
  /**
   * Creates an empty SchemaRDD backed by a ParquetRelation.
   *
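A note on the test idiom above: SchemaRDD rows surface complex columns as Any, so both map tests install an implicit conversion that casts the cell to the expected runtime type before calling map methods on it. Below is a minimal standalone analogue with invented names, assuming the map cell really holds the mutable.HashMap produced by CatalystMapConverter.

import scala.collection.mutable

object ImplicitCastSketch extends App {
  // Invented helper mirroring the tests' anyToMap trick.
  implicit def anyToStringIntMap(value: Any): mutable.HashMap[String, Int] =
    value.asInstanceOf[mutable.HashMap[String, Int]]

  // Stand-in for a row cell of a map column (a Row field is typed Any).
  val cell: Any = mutable.HashMap[String, Int]("key1" -> 1, "key2" -> 2)
  // The implicit view lets map methods be called directly on the Any value.
  println(cell.getOrElse("key1", 0))  // 1
}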
