Commit 0b412b0

Amortize map construction.
1 parent 38e8a99

File tree

4 files changed: +27 -19 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 4 additions & 16 deletions
@@ -82,31 +82,19 @@ object BindReferences extends Logging {
   def bindReference[A <: Expression](
       expression: A,
-      input: Seq[Attribute],
+      input: AttributeSeq,
       allowFailures: Boolean = false): A = {
-    val inputArr = input.toArray
-    val inputToOrdinal = {
-      val map = new java.util.HashMap[ExprId, Int](inputArr.length * 2)
-      var index = 0
-      input.foreach { attr =>
-        if (!map.containsKey(attr.exprId)) {
-          map.put(attr.exprId, index)
-        }
-        index += 1
-      }
-      map
-    }
     expression.transform { case a: AttributeReference =>
       attachTree(a, "Binding attribute") {
-        val ordinal = Option(inputToOrdinal.get(a.exprId)).getOrElse(-1)
+        val ordinal = input.getOrdinal(a.exprId)
         if (ordinal == -1) {
           if (allowFailures) {
             a
           } else {
-            sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
           }
         } else {
-          BoundReference(ordinal, a.dataType, inputArr(ordinal).nullable)
+          BoundReference(ordinal, a.dataType, input(ordinal).nullable)
         }
       }
     }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
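
For context on what is being amortized: `bindReference` is invoked once per expression, so the deleted block rebuilt the same `ExprId`-to-ordinal map every time a batch of expressions was bound against one attribute list. After the change, the map lives with the attribute list and is built at most once. A minimal sketch of the two shapes, using plain Scala maps and illustrative names (`Schema`, `bindEach`, `bindAll` are not Spark's API):

// Before: every call pays to build the ordinal index from scratch.
def bindEach(schema: Seq[String], exprs: Seq[String]): Seq[Int] =
  exprs.map { e =>
    // reverse + toMap keeps the FIRST index for duplicate names,
    // mirroring the `containsKey` guard in the deleted block.
    val ordinals = schema.zipWithIndex.reverse.toMap
    ordinals.getOrElse(e, -1)
  }

// After: the index is built lazily, once, and shared by all calls.
final class Schema(attrs: Seq[String]) {
  private lazy val ordinals: Map[String, Int] = attrs.zipWithIndex.reverse.toMap
  def getOrdinal(name: String): Int = ordinals.getOrElse(name, -1)
}

def bindAll(schema: Schema, exprs: Seq[String]): Seq[Int] =
  exprs.map(schema.getOrdinal) // reuses the one cached map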

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala

Lines changed: 21 additions & 1 deletion
@@ -86,11 +86,31 @@ package object expressions {
   /**
    * Helper functions for working with `Seq[Attribute]`.
    */
-  implicit class AttributeSeq(attrs: Seq[Attribute]) {
+  implicit class AttributeSeq(val attrs: Seq[Attribute]) {
     /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
     def toStructType: StructType = {
       StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
     }
+
+    private lazy val inputArr = attrs.toArray
+
+    private lazy val inputToOrdinal = {
+      val map = new java.util.HashMap[ExprId, Int](inputArr.length * 2)
+      var index = 0
+      attrs.foreach { attr =>
+        if (!map.containsKey(attr.exprId)) {
+          map.put(attr.exprId, index)
+        }
+        index += 1
+      }
+      map
+    }
+
+    def apply(ordinal: Int): Attribute = inputArr(ordinal)
+
+    def getOrdinal(exprId: ExprId): Int = {
+      Option(inputToOrdinal.get(exprId)).getOrElse(-1)
+    }
   }

   /**
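
Two details of the new members are worth noting. `AttributeSeq` is an implicit class, so any `Seq[Attribute]` picks up `apply` and `getOrdinal` without explicit wrapping, and promoting `attrs` to a `val` is what lets `BoundAttribute.scala` reach the underlying sequence for its error message. The initial `HashMap` capacity of `inputArr.length * 2` also keeps the map under its default load factor, so it should not rehash while being filled. A small sketch of the implicit-class pattern, simplified to `String` elements (the names here are illustrative, not Spark's):

object ImplicitSeqExample {
  implicit class NamedSeq(val names: Seq[String]) {
    private lazy val ordinals: Map[String, Int] =
      names.zipWithIndex.reverse.toMap // first occurrence wins on duplicates
    def getOrdinal(name: String): Int = ordinals.getOrElse(name, -1)
  }

  def main(args: Array[String]): Unit = {
    val cols = Seq("id", "name", "id")
    println(cols.getOrdinal("id"))  // 0: the compiler inserts the conversion
    println(cols.getOrdinal("age")) // -1 for a missing name
  }
}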

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
   /**
    * All the attributes that are used for this plan.
    */
-  lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output)
+  lazy val allAttributes: AttributeSeq = children.flatMap(_.output)

   private def cleanExpression(e: Expression): Expression = e match {
     case a: Alias =>
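
Changing the declared type, rather than leaving it as `Seq[Attribute]` and relying on the implicit conversion at each use site, is what makes the caching stick: the conversion allocates a fresh wrapper every time it fires, so a lazy map inside it would be rebuilt per use, whereas a field typed `AttributeSeq` holds one wrapper (and one map) for the lifetime of the plan node. Continuing the simplified sketch from above, with `NamedSeq` standing in for `AttributeSeq`:

// Conversion fires on every call: a fresh NamedSeq each time, so the
// lazy map inside it is rebuilt per invocation.
def perCall(cols: Seq[String], name: String): Int = cols.getOrdinal(name)

// Conversion fires once, at assignment: all later calls share one wrapper
// and therefore one lazily built map.
val cached: NamedSeq = Seq("id", "name", "score")
def amortized(name: String): Int = cached.getOrdinal(name)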

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ case class HashAggregateExec(

   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

-  override lazy val allAttributes: Seq[Attribute] =
+  override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
       aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
