
Commit 1c34877

Get the nested fields without modifying the column names
1 parent 46c2474 commit 1c34877

4 files changed (+46, -82 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala

Lines changed: 0 additions & 23 deletions
@@ -279,29 +279,6 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
     StructType(fields.filter(f => names.contains(f.name)))
   }
 
-  /**
-   * Extracts the [[StructField]] with the given name recursively.
-   *
-   * @throws IllegalArgumentException if the parent field's type is not StructType
-   */
-  def getFieldRecursively(name: String): StructField = {
-    if (name.contains(',')) {
-      val curFieldStr = name.split(",", 2)(0)
-      val nextFieldStr = name.split(",", 2)(1)
-      val curField = this.apply(curFieldStr)
-      curField.dataType match {
-        case st: StructType =>
-          val newField = StructType(st.fields).getFieldRecursively(nextFieldStr)
-          StructField(curField.name, StructType(Seq(newField)),
-            curField.nullable, curField.metadata)
-        case _ =>
-          throw new IllegalArgumentException(s"""Field "$curFieldStr" is not struct field.""")
-      }
-    } else {
-      this.apply(name)
-    }
-  }
-
   /**
    * Returns the index of a given field.
    *
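
For readers following the change: the deleted helper encoded a nested path as a single comma-separated string ("col,s1,s1_1"), which is exactly what collided with column names that themselves contain commas. Below is a minimal standalone sketch of the same recursion over a List[String] path, the representation the new FileSourceStrategy code adopts; the schema and field names are illustrative, not taken from this commit.

    import org.apache.spark.sql.types._

    // Illustrative schema: col: struct<s1: struct<s1_1: long, s1_2: long>, str: string>
    val schema = StructType(Seq(
      StructField("col", StructType(Seq(
        StructField("s1", StructType(Seq(
          StructField("s1_1", LongType),
          StructField("s1_2", LongType)))),
        StructField("str", StringType)))),
      StructField("num", LongType)))

    // Same shape as the deleted getFieldRecursively: walk the path, wrap each
    // level in a one-field StructType, and fail if an intermediate field is
    // not a struct.
    def fieldByPath(schema: StructType, path: List[String]): StructField = path match {
      case head :: Nil => schema(head)
      case head :: rest =>
        val cur = schema(head)
        cur.dataType match {
          case st: StructType =>
            StructField(cur.name, StructType(Seq(fieldByPath(st, rest))),
              cur.nullable, cur.metadata)
          case _ =>
            throw new IllegalArgumentException(s"""Field "$head" is not struct field.""")
        }
      case Nil =>
        throw new IllegalArgumentException("empty field path")
    }

    // fieldByPath(schema, List("col", "s1", "s1_1")) returns the single-branch
    // field col: struct<s1: struct<s1_1: long>>, everything else pruned away.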

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala

Lines changed: 33 additions & 39 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
  * A strategy for planning scans over collections of files that might be partitioned or bucketed
@@ -99,11 +99,9 @@ object FileSourceStrategy extends Strategy with Logging {
           .filter(requiredAttributes.contains)
           .filterNot(partitionColumns.contains)
       val outputSchema = if (fsRelation.sqlContext.conf.isParquetNestColumnPruning) {
-        val requiredColumnsWithNesting = generateRequiredColumnsContainsNesting(
-          projects, readDataColumns.attrs.map(_.name).toArray)
         val totalSchema = readDataColumns.toStructType
-        val prunedSchema = StructType(requiredColumnsWithNesting
-          .map(totalSchema.getFieldRecursively))
+        val prunedSchema = StructType(
+          generateStructFieldsContainsNesting(projects, totalSchema))
         // Merge schema in same StructType and merge with filterAttributes
         prunedSchema.fields.map(f => StructType(Array(f))).reduceLeft(_ merge _)
           .merge(filterAttributes.toSeq.toStructType)
@@ -137,55 +135,51 @@
     case _ => Nil
   }
 
-  private def generateRequiredColumnsContainsNesting(projects: Seq[Expression],
-                                                     columns: Array[String]) : Array[String] = {
-    def generateAttributeMap(nestFieldMap: scala.collection.mutable.Map[String, Seq[String]],
-                             isNestField: Boolean, curString: Option[String],
-                             node: Expression) {
+  private def generateStructFieldsContainsNesting(projects: Seq[Expression],
+                                                  totalSchema: StructType) : Seq[StructField] = {
+    def generateStructField(curField: List[String],
+                            node: Expression) : Seq[StructField] = {
       node match {
         case ai: GetArrayItem =>
-          // Here we drop the curString for simplify array and map support.
+          // Here we drop the previous for simplify array and map support.
           // Same strategy in GetArrayStructFields and GetMapValue
-          generateAttributeMap(nestFieldMap, isNestField = true, None, ai.child)
-
+          generateStructField(List.empty[String], ai.child)
         case asf: GetArrayStructFields =>
-          generateAttributeMap(nestFieldMap, isNestField = true, None, asf.child)
-
+          generateStructField(List.empty[String], asf.child)
         case mv: GetMapValue =>
-          generateAttributeMap(nestFieldMap, isNestField = true, None, mv.child)
-
+          generateStructField(List.empty[String], mv.child)
         case attr: AttributeReference =>
-          if (isNestField && curString.isDefined) {
-            val attrStr = attr.name
-            if (nestFieldMap.contains(attrStr)) {
-              nestFieldMap(attrStr) = nestFieldMap(attrStr) ++ Seq(attrStr + "," + curString.get)
-            } else {
-              nestFieldMap += (attrStr -> Seq(attrStr + "," + curString.get))
-            }
-          }
+          Seq(getFieldRecursively(totalSchema, attr.name :: curField))
        case sf: GetStructField =>
-          val str = if (curString.isDefined) {
-            sf.name.get + "," + curString.get
-          } else sf.name.get
-          generateAttributeMap(nestFieldMap, isNestField = true, Option(str), sf.child)
+          generateStructField(sf.name.get :: curField, sf.child)
         case _ =>
           if (node.children.nonEmpty) {
-            node.children.foreach(child => generateAttributeMap(nestFieldMap,
-              isNestField, curString, child))
+            node.children.flatMap(child => generateStructField(curField, child))
+          } else {
+            Seq.empty[StructField]
           }
       }
     }
 
-    val nestFieldMap = scala.collection.mutable.Map.empty[String, Seq[String]]
-    projects.foreach(p => generateAttributeMap(nestFieldMap, isNestField = false, None, p))
-    val col_list = columns.toList.flatMap(col => {
-      if (nestFieldMap.contains(col)) {
-        nestFieldMap.get(col).get.toList
+    def getFieldRecursively(totalSchema: StructType,
+                            name: List[String]): StructField = {
+      if (name.length > 1) {
+        val curField = name.head
+        val curFieldType = totalSchema(curField)
+        curFieldType.dataType match {
+          case st: StructType =>
+            val newField = getFieldRecursively(StructType(st.fields), name.drop(1))
+            StructField(curFieldType.name, StructType(Seq(newField)),
+              curFieldType.nullable, curFieldType.metadata)
+          case _ =>
+            throw new IllegalArgumentException(s"""Field "$curField" is not struct field.""")
        }
      } else {
-        List(col)
+        totalSchema(name.head)
      }
-    })
-    col_list.toArray
+    }
+
+    projects.flatMap(p => generateStructField(List.empty[String], p))
  }
 
 }
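
The rewrite walks the project expressions directly instead of round-tripping through strings: each GetStructField pushes its field name onto curField, the AttributeReference at the leaf prepends the top-level column name, and getFieldRecursively turns the resulting List("col", "s1", "s1_1") into a single-branch schema. Column names are never concatenated and re-split, so names containing commas survive intact. The per-projection branches are then unioned by the reduceLeft(_ merge _) step above. StructType.merge is package-private in Spark, so what follows is a hand-rolled sketch of the recursive union it performs, reusing the illustrative field names from the sketch above:

    import org.apache.spark.sql.types._

    // Stand-in for the (private[sql]) StructType.merge used in FileSourceStrategy:
    // union fields by name, recursing into nested structs.
    def mergeStructs(left: StructType, right: StructType): StructType = {
      val leftNames = left.fieldNames.toSet
      val merged = left.fields.map { lf =>
        right.fields.find(_.name == lf.name) match {
          case Some(rf) =>
            (lf.dataType, rf.dataType) match {
              case (ls: StructType, rs: StructType) => lf.copy(dataType = mergeStructs(ls, rs))
              case _ => lf
            }
          case None => lf
        }
      }
      StructType(merged ++ right.fields.filterNot(f => leftNames.contains(f.name)))
    }

    // The two single-branch schemas produced for "select col.s1.s1_1, col.str ...":
    val branch1 = StructType(Seq(StructField("col", StructType(Seq(
      StructField("s1", StructType(Seq(StructField("s1_1", LongType)))))))))
    val branch2 = StructType(Seq(StructField("col", StructType(Seq(
      StructField("str", StringType))))))

    // mergeStructs(branch1, branch2) yields col: struct<s1: struct<s1_1: long>, str: string>,
    // i.e. the "col.[s1.s1_1, str]" merge the updated test below describes.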
Binary file not shown (the new Parquet test fixture read by the updated test below).

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala

Lines changed: 13 additions & 20 deletions
@@ -574,36 +574,29 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
   test("SPARK-4502 parquet nested fields pruning") {
     // Schema of "test-data/nested-array-struct.parquet":
     // root
-    // |-- primitive: integer (nullable = true)
-    // |-- myComplex: array (nullable = true)
-    // |    |-- element: struct (containsNull = true)
-    // |    |    |-- id: integer (nullable = true)
-    // |    |    |-- repeatedMessage: array (nullable = true)
-    // |    |    |    |-- element: struct (containsNull = true)
-    // |    |    |    |    |-- someId: integer (nullable = true)
-    val df = readResourceParquetFile("test-data/nested-array-struct.parquet")
+    // |-- col: struct (nullable = true)
+    // |    |-- s1: struct (nullable = true)
+    // |    |    |-- s1_1: long (nullable = true)
+    // |    |    |-- s1_2: long (nullable = true)
+    // |    |-- str: string (nullable = true)
+    // |-- num: long (nullable = true)
+    // |-- str: string (nullable = true)
+    val df = readResourceParquetFile("test-data/nested-struct.snappy.parquet")
     df.createOrReplaceTempView("tmp_table")
     // normal test
-    val query1 = "select primitive,myComplex[0].id from tmp_table"
+    val query1 = "select num,col.s1.s1_1 from tmp_table"
     val result1 = sql(query1)
     withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") {
       checkAnswer(sql(query1), result1)
     }
-    // test for array in struct
-    val query2 = "select primitive,myComplex[0].repeatedMessage[0].someId from tmp_table"
+    // test for same struct meta merge
+    // col.s1.s1_1 and col.str should merge
+    // like col.[s1.s1_1, str] before pass to parquet
+    val query2 = "select col.s1.s1_1,col.str from tmp_table"
     val result2 = sql(query2)
     withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") {
       checkAnswer(sql(query2), result2)
     }
-    // test for same struct meta merge
-    // myComplex.id and myComplex.repeatedMessage.someId should merge
-    // like myComplex.[id, repeatedMessage.someId] before pass to parquet
-    val query3 = "select myComplex[0].id, myComplex[0].repeatedMessage[0].someId" +
-      " from tmp_table"
-    val result3 = sql(query3)
-    withSQLConf(SQLConf.PARQUET_NEST_COLUMN_PRUNING.key -> "true") {
-      checkAnswer(sql(query3), result3)
-    }
 
     spark.sessionState.catalog.dropTable(
       TableIdentifier("tmp_table"), ignoreIfNotExists = true, purge = false)
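
For completeness, a hedged end-to-end sketch of what the updated test exercises. SQLConf.PARQUET_NEST_COLUMN_PRUNING exists only in this patch, and the conf key string below is an assumption about what it resolves to, not something taken from the diff:

    // Assumed key for this patch's SQLConf.PARQUET_NEST_COLUMN_PRUNING;
    // not a setting in stock Spark.
    spark.conf.set("spark.sql.parquet.nestColumnPruning", "true")

    // Any Parquet file with the nested schema shown in the test comment works here.
    val df = spark.read.parquet("/path/to/nested-struct.snappy.parquet")
    df.createOrReplaceTempView("tmp_table")

    // With pruning enabled, the file scan should request only num and col.s1.s1_1
    // from Parquet rather than the whole col struct.
    spark.sql("select num, col.s1.s1_1 from tmp_table").show()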
