
Commit 5363783

[SPARK-15764][SQL] Replace N^2 loop in BindReferences
BindReferences contains an n^2 loop which causes performance issues when operating over large schemas: to determine the ordinal of an attribute reference, we perform a linear scan over the `input` array. Because input can sometimes be a `List`, the call to `input(ordinal).nullable` can also be O(n).

Instead of performing a linear scan, we can convert the input into an array and build a hash map to map from expression ids to ordinals. The greater up-front cost of the map construction is offset by the fact that an expression can contain multiple attribute references, so the cost of the map construction is amortized across a number of lookups.

Perf. benchmarks to follow. /cc ericl

Author: Josh Rosen <joshrosen@databricks.com>

Closes #13505 from JoshRosen/bind-references-improvement.

(cherry picked from commit 0b8d694)

Signed-off-by: Josh Rosen <joshrosen@databricks.com>
1 parent d07bce4 commit 5363783
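A minimal, self-contained sketch of the before/after lookup strategy described above (the names SimpleAttr, ordinalSlow, and ordinalFast are hypothetical stand-ins for illustration, not Spark classes):

object OrdinalLookupSketch {
  final case class SimpleAttr(exprId: Long, name: String)

  def main(args: Array[String]): Unit = {
    // `List` is a linked list, so indexed access and scans are both O(n).
    val input: Seq[SimpleAttr] =
      List(SimpleAttr(10L, "a"), SimpleAttr(11L, "b"), SimpleAttr(12L, "c"))

    // Before: a linear scan per attribute reference, O(n^2) over a whole schema.
    def ordinalSlow(exprId: Long): Int = input.indexWhere(_.exprId == exprId)

    // After: one up-front pass builds an array plus a hash map; every later
    // lookup is O(1). (Unlike the actual patch, this simplified map keeps the
    // last duplicate expression id rather than the first.)
    val arr = input.toArray
    val ordinals: Map[Long, Int] = arr.iterator.map(_.exprId).zipWithIndex.toMap
    def ordinalFast(exprId: Long): Int = ordinals.getOrElse(exprId, -1)

    assert(ordinalSlow(11L) == ordinalFast(11L))
    println(ordinalFast(12L)) // 2
    println(ordinalFast(99L)) // -1, mirroring indexWhere's miss behavior
  }
}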

File tree

6 files changed: 40 additions, 15 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala

Lines changed: 0 additions & 7 deletions

@@ -26,13 +26,6 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
-
-  /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */
-  def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex)
-
-  /** Given a schema, constructs a map from ordinal to Attribute. */
-  def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] =
-    schema.zipWithIndex.map { case (a, i) => i -> a }.toMap
 }
 
 class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 3 additions & 3 deletions

@@ -82,16 +82,16 @@ object BindReferences extends Logging {
 
   def bindReference[A <: Expression](
       expression: A,
-      input: Seq[Attribute],
+      input: AttributeSeq,
       allowFailures: Boolean = false): A = {
     expression.transform { case a: AttributeReference =>
       attachTree(a, "Binding attribute") {
-        val ordinal = input.indexWhere(_.exprId == a.exprId)
+        val ordinal = input.indexOf(a.exprId)
        if (ordinal == -1) {
          if (allowFailures) {
            a
          } else {
-            sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
          }
        } else {
          BoundReference(ordinal, a.dataType, input(ordinal).nullable)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala

Lines changed: 33 additions & 1 deletion

@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst
 
+import com.google.common.collect.Maps
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{StructField, StructType}
 
@@ -86,11 +88,41 @@ package object expressions {
   /**
    * Helper functions for working with `Seq[Attribute]`.
   */
-  implicit class AttributeSeq(attrs: Seq[Attribute]) {
+  implicit class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable {
     /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
     def toStructType: StructType = {
       StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
     }
+
+    // It's possible that `attrs` is a linked list, which can lead to bad O(n^2) loops when
+    // accessing attributes by their ordinals. To avoid this performance penalty, convert the input
+    // to an array.
+    @transient private lazy val attrsArray = attrs.toArray
+
+    @transient private lazy val exprIdToOrdinal = {
+      val arr = attrsArray
+      val map = Maps.newHashMapWithExpectedSize[ExprId, Int](arr.length)
+      // Iterate over the array in reverse order so that the final map value is the first attribute
+      // with a given expression id.
+      var index = arr.length - 1
+      while (index >= 0) {
+        map.put(arr(index).exprId, index)
+        index -= 1
+      }
+      map
+    }
+
+    /**
+     * Returns the attribute at the given index.
+     */
+    def apply(ordinal: Int): Attribute = attrsArray(ordinal)
+
+    /**
+     * Returns the index of first attribute with a matching expression id, or -1 if no match exists.
+     */
+    def indexOf(exprId: ExprId): Int = {
+      Option(exprIdToOrdinal.get(exprId)).getOrElse(-1)
+    }
   }
 
   /**
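The change above leans on Scala's implicit-class mechanism: ascribing a `Seq[Attribute]` to the `AttributeSeq` type wraps it once, so the lazily built array and hash map are cached on the wrapper and reused across lookups. A sketch of that mechanism with stand-in types (Attr and Id here are hypothetical, not Spark classes):

object ImplicitSeqSketch {
  final case class Id(value: Long)
  final case class Attr(exprId: Id, name: String)

  implicit class AttrSeq(val attrs: Seq[Attr]) extends Serializable {
    // Cached on first use; @transient so the cache is rebuilt after deserialization.
    @transient private lazy val arr = attrs.toArray
    @transient private lazy val byId: Map[Id, Int] =
      arr.iterator.map(_.exprId).zipWithIndex.toMap

    def apply(ordinal: Int): Attr = arr(ordinal)
    def indexOf(id: Id): Int = byId.getOrElse(id, -1)
  }

  def main(args: Array[String]): Unit = {
    val output: Seq[Attr] = List(Attr(Id(1L), "x"), Attr(Id(2L), "y"))
    // Ascribing the wrapper type converts once up front; passing a bare
    // Seq[Attr] around would re-wrap (and rebuild the map) at each call site.
    val out: AttrSeq = output
    println(out.indexOf(Id(2L))) // 1
    println(out(0).name)         // x
  }
}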

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 1 addition & 1 deletion

@@ -296,7 +296,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   /**
    * All the attributes that are used for this plan.
   */
-  lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output)
+  lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
 
   private def cleanExpression(e: Expression): Expression = e match {
     case a: Alias =>

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ case class HashAggregateExec(
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
-  override lazy val allAttributes: Seq[Attribute] =
+  override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
       aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 2 additions & 2 deletions

@@ -310,7 +310,7 @@ private[sql] case class InMemoryTableScanExec(
     // within the map Partitions closure.
     val schema = relation.partitionStatistics.schema
     val schemaIndex = schema.zipWithIndex
-    val relOutput = relation.output
+    val relOutput: AttributeSeq = relation.output
     val buffers = relation.cachedColumnBuffers
 
     buffers.mapPartitionsInternal { cachedBatchIterator =>
@@ -321,7 +321,7 @@ private[sql] case class InMemoryTableScanExec(
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
         attributes.map { a =>
-          relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType
+          relOutput.indexOf(a.exprId) -> a.dataType
        }.unzip
 
       // Do partition batch pruning if enabled
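The commit message defers performance numbers, but a rough micro-benchmark along these lines illustrates the expected shape of the win. This is an illustrative sketch only, not the benchmark referenced above; the object name, attribute type, and sizes are arbitrary:

object BindLookupBench {
  final case class Attr(exprId: Long)

  def time(label: String)(body: => Unit): Unit = {
    val start = System.nanoTime()
    body
    println(f"$label: ${(System.nanoTime() - start) / 1e6}%.1f ms")
  }

  def main(args: Array[String]): Unit = {
    val n = 5000
    val input: List[Attr] = List.tabulate(n)(i => Attr(i.toLong))

    // O(n^2): one linear scan over the linked list per lookup.
    time("indexWhere") {
      var i = 0
      while (i < n) { input.indexWhere(_.exprId == i.toLong); i += 1 }
    }

    // O(n) to build the map once, then O(1) per lookup.
    time("hash map") {
      val ordinals = input.iterator.map(_.exprId).zipWithIndex.toMap
      var i = 0
      while (i < n) { ordinals.getOrElse(i.toLong, -1); i += 1 }
    }
  }
}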
