@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
25
25
import org .apache .spark .sql .catalyst .plans .logical .LogicalPlan
26
26
import org .apache .spark .sql .execution .FileSourceScanExec
27
27
import org .apache .spark .sql .execution .SparkPlan
28
- import org .apache .spark .sql .types .StructType
28
+ import org .apache .spark .sql .types .{ StructField , StructType }
29
29
30
30
/**
31
31
* A strategy for planning scans over collections of files that might be partitioned or bucketed
@@ -99,11 +99,9 @@ object FileSourceStrategy extends Strategy with Logging {
99
99
.filter(requiredAttributes.contains)
100
100
.filterNot(partitionColumns.contains)
101
101
val outputSchema = if (fsRelation.sqlContext.conf.isParquetNestColumnPruning) {
102
- val requiredColumnsWithNesting = generateRequiredColumnsContainsNesting(
103
- projects, readDataColumns.attrs.map(_.name).toArray)
104
102
val totalSchema = readDataColumns.toStructType
105
- val prunedSchema = StructType (requiredColumnsWithNesting
106
- .map( totalSchema.getFieldRecursively ))
103
+ val prunedSchema = StructType (
104
+ generateStructFieldsContainsNesting(projects, totalSchema))
107
105
// Merge schema in same StructType and merge with filterAttributes
108
106
prunedSchema.fields.map(f => StructType (Array (f))).reduceLeft(_ merge _)
109
107
.merge(filterAttributes.toSeq.toStructType)
@@ -137,55 +135,51 @@ object FileSourceStrategy extends Strategy with Logging {
137
135
case _ => Nil
138
136
}
139
137
140
- private def generateRequiredColumnsContainsNesting (projects : Seq [Expression ],
141
- columns : Array [String ]) : Array [String ] = {
142
- def generateAttributeMap (nestFieldMap : scala.collection.mutable.Map [String , Seq [String ]],
143
- isNestField : Boolean , curString : Option [String ],
144
- node : Expression ) {
138
+ private def generateStructFieldsContainsNesting (projects : Seq [Expression ],
139
+ totalSchema : StructType ) : Seq [StructField ] = {
140
+ def generateStructField (curField : List [String ],
141
+ node : Expression ) : Seq [StructField ] = {
145
142
node match {
146
143
case ai : GetArrayItem =>
147
- // Here we drop the curString for simplify array and map support.
144
+ // Here we drop the previous for simplify array and map support.
148
145
// Same strategy in GetArrayStructFields and GetMapValue
149
- generateAttributeMap(nestFieldMap, isNestField = true , None , ai.child)
150
-
146
+ generateStructField(List .empty[String ], ai.child)
151
147
case asf : GetArrayStructFields =>
152
- generateAttributeMap(nestFieldMap, isNestField = true , None , asf.child)
153
-
148
+ generateStructField(List .empty[String ], asf.child)
154
149
case mv : GetMapValue =>
155
- generateAttributeMap(nestFieldMap, isNestField = true , None , mv.child)
156
-
150
+ generateStructField(List .empty[String ], mv.child)
157
151
case attr : AttributeReference =>
158
- if (isNestField && curString.isDefined) {
159
- val attrStr = attr.name
160
- if (nestFieldMap.contains(attrStr)) {
161
- nestFieldMap(attrStr) = nestFieldMap(attrStr) ++ Seq (attrStr + " ," + curString.get)
162
- } else {
163
- nestFieldMap += (attrStr -> Seq (attrStr + " ," + curString.get))
164
- }
165
- }
152
+ Seq (getFieldRecursively(totalSchema, attr.name :: curField))
166
153
case sf : GetStructField =>
167
- val str = if (curString.isDefined) {
168
- sf.name.get + " ," + curString.get
169
- } else sf.name.get
170
- generateAttributeMap(nestFieldMap, isNestField = true , Option (str), sf.child)
154
+ generateStructField(sf.name.get :: curField, sf.child)
171
155
case _ =>
172
156
if (node.children.nonEmpty) {
173
- node.children.foreach(child => generateAttributeMap(nestFieldMap,
174
- isNestField, curString, child))
157
+ node.children.flatMap(child => generateStructField(curField, child))
158
+ } else {
159
+ Seq .empty[StructField ]
175
160
}
176
161
}
177
162
}
178
163
179
- val nestFieldMap = scala.collection.mutable.Map .empty[String , Seq [String ]]
180
- projects.foreach(p => generateAttributeMap(nestFieldMap, isNestField = false , None , p))
181
- val col_list = columns.toList.flatMap(col => {
182
- if (nestFieldMap.contains(col)) {
183
- nestFieldMap.get(col).get.toList
164
+ def getFieldRecursively (totalSchema : StructType ,
165
+ name : List [String ]): StructField = {
166
+ if (name.length > 1 ) {
167
+ val curField = name.head
168
+ val curFieldType = totalSchema(curField)
169
+ curFieldType.dataType match {
170
+ case st : StructType =>
171
+ val newField = getFieldRecursively(StructType (st.fields), name.drop(1 ))
172
+ StructField (curFieldType.name, StructType (Seq (newField)),
173
+ curFieldType.nullable, curFieldType.metadata)
174
+ case _ =>
175
+ throw new IllegalArgumentException (s """ Field " $curField" is not struct field. """ )
176
+ }
184
177
} else {
185
- List (col )
178
+ totalSchema(name.head )
186
179
}
187
- })
188
- col_list.toArray
180
+ }
181
+
182
+ projects.flatMap(p => generateStructField(List .empty[String ], p))
189
183
}
190
184
191
185
}
0 commit comments