Skip to content

Commit 31f8bd0

Browse files
committed
Address some comments
Signed-off-by: Karen Feng <karen.feng@databricks.com>
1 parent 80fa5c3 commit 31f8bd0

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ package org.apache.spark.sql.catalyst.expressions
2323
* of the name, or the expected nullability).
2424
*/
2525
object AttributeMap {
26+
def apply[A](kvs: Map[Attribute, A]): AttributeMap[A] = {
27+
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)))
28+
}
29+
2630
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
2731
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
2832
}

sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ package org.apache.spark.sql.catalyst.expressions
2323
* of the name, or the expected nullability).
2424
*/
2525
object AttributeMap {
26+
def apply[A](kvs: Map[Attribute, A]): AttributeMap[A] = {
27+
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)))
28+
}
29+
2630
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
2731
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
2832
}
@@ -37,6 +41,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
3741

3842
override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
3943

44+
override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 = get(k).getOrElse(default)
45+
4046
override def contains(k: Attribute): Boolean = get(k).isDefined
4147

4248
override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,7 @@ object NestedColumnAliasing {
146146
val nestedFieldToAlias = attributeToExtractValuesAndAliases.values.flatten.toMap
147147

148148
// A reference attribute can have multiple aliases for nested fields.
149-
val attrToAliases = new AttributeMap(
150-
attributeToExtractValuesAndAliases.map { case (attr, evAliasSeq) =>
151-
attr.exprId -> (attr, evAliasSeq.map(_._2))
152-
}
153-
)
149+
val attrToAliases = AttributeMap(attributeToExtractValuesAndAliases.mapValues(_.map(_._2)))
154150

155151
plan match {
156152
case Project(projectList, child) =>
@@ -245,7 +241,9 @@ object NestedColumnAliasing {
245241
val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
246242
exprList.foreach { e =>
247243
collectRootReferenceAndExtractValue(e).foreach {
248-
case ev: ExtractValue => nestedFieldReferences.append(ev)
244+
case ev: ExtractValue =>
245+
assert(ev.references.size == 1, s"$ev should have one reference")
246+
nestedFieldReferences.append(ev)
249247
case ar: AttributeReference => otherRootReferences.append(ar)
250248
}
251249
}
@@ -306,7 +304,7 @@ object GeneratorNestedColumnAliasing {
306304
// when `nestedSchemaPruningEnabled` is on, nested columns will be pruned further at
307305
// file format readers if it is supported.
308306
case Project(projectList, g: Generate) if (SQLConf.get.nestedPruningOnExpressions ||
309-
SQLConf.get.nestedSchemaPruningEnabled) && canPruneGenerator(g.generator) =>
307+
SQLConf.get.nestedSchemaPruningEnabled) && canPruneGenerator(g.generator) =>
310308
// On top on `Generate`, a `Project` that might have nested column accessors.
311309
// We try to get alias maps for both project list and generator's children expressions.
312310
val attrToExtractValues = NestedColumnAliasing.getAttributeToExtractValues(
@@ -373,9 +371,7 @@ object GeneratorNestedColumnAliasing {
373371
val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
374372

375373
// Replace nested column accessor with generator output.
376-
val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.map { case (attr, _) =>
377-
attr.exprId
378-
}.toSet
374+
val attrExprIdsOnGenerator = attrToExtractValuesOnGenerator.keys.map(_.exprId).toSet
379375
val updatedProject = p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
380376
case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
381377
updatedGenerate.output

0 commit comments

Comments
 (0)