
[SPARK-15764][SQL] Replace N^2 loop in BindReferences #13505

Closed · wants to merge 14 commits
@@ -26,13 +26,6 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
-
-  /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */
-  def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex)
-
-  /** Given a schema, constructs a map from ordinal to Attribute. */
-  def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] =
-    schema.zipWithIndex.map { case (a, i) => i -> a }.toMap
 }
 
 class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
Contributor Author commented:

This was vaguely related yet unused code that I stumbled across while looking for similar occurrences of this pattern, so I decided to remove it.

@@ -82,16 +82,16 @@ object BindReferences extends Logging {

   def bindReference[A <: Expression](
       expression: A,
-      input: Seq[Attribute],
+      input: AttributeSeq,
       allowFailures: Boolean = false): A = {
     expression.transform { case a: AttributeReference =>
       attachTree(a, "Binding attribute") {
-        val ordinal = input.indexWhere(_.exprId == a.exprId)
+        val ordinal = input.indexOf(a.exprId)
         if (ordinal == -1) {
           if (allowFailures) {
             a
           } else {
-            sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
           }
         } else {
           BoundReference(ordinal, a.dataType, input(ordinal).nullable)
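This two-line swap is the heart of the PR: `indexWhere` re-scans `input` linearly for every attribute being bound, so binding all the references in an n-column schema costs O(n^2), whereas the new `AttributeSeq.indexOf` (added below) answers each lookup from a prebuilt hash map in O(1). A minimal, self-contained Scala sketch of the two approaches, using a hypothetical `Attr` stand-in rather than Spark's own classes:

object LookupSketch {
  final case class Attr(exprId: Long)

  // Before: every lookup is a linear scan, so n lookups over an
  // n-attribute schema cost O(n^2) in total.
  def quadratic(output: Seq[Attr], wanted: Seq[Attr]): Seq[Int] =
    wanted.map(a => output.indexWhere(_.exprId == a.exprId))

  // After: build an exprId -> ordinal map once, then each lookup is O(1).
  // Reversing before toMap lets the first occurrence of an exprId win,
  // mirroring the reverse iteration in the PR's exprIdToOrdinal map.
  def linear(output: Seq[Attr], wanted: Seq[Attr]): Seq[Int] = {
    val ordinals: Map[Long, Int] =
      output.zipWithIndex.reverse.map { case (a, i) => a.exprId -> i }.toMap
    wanted.map(a => ordinals.getOrElse(a.exprId, -1))
  }
}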
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst
 
+import com.google.common.collect.Maps
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{StructField, StructType}
 
@@ -86,11 +88,41 @@ package object expressions {
   /**
    * Helper functions for working with `Seq[Attribute]`.
    */
-  implicit class AttributeSeq(attrs: Seq[Attribute]) {
+  implicit class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable {
     /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
     def toStructType: StructType = {
       StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
     }
+
+    // It's possible that `attrs` is a linked list, which can lead to bad O(n^2) loops when
+    // accessing attributes by their ordinals. To avoid this performance penalty, convert the input
+    // to an array.
+    @transient private lazy val attrsArray = attrs.toArray
+
+    @transient private lazy val exprIdToOrdinal = {
+      val arr = attrsArray
+      val map = Maps.newHashMapWithExpectedSize[ExprId, Int](arr.length)
+      // Iterate over the array in reverse order so that the final map value is the first attribute
+      // with a given expression id.
+      var index = arr.length - 1
+      while (index >= 0) {
+        map.put(arr(index).exprId, index)
+        index -= 1
+      }
+      map
+    }
+
+    /**
+     * Returns the attribute at the given index.
+     */
+    def apply(ordinal: Int): Attribute = attrsArray(ordinal)
+
+    /**
+     * Returns the index of first attribute with a matching expression id, or -1 if no match exists.
+     */
+    def indexOf(exprId: ExprId): Int = {
+      Option(exprIdToOrdinal.get(exprId)).getOrElse(-1)
+    }
   }
 
   /**
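To make the new wrapper concrete, here is a small usage sketch (my example, not part of the PR, and assuming a build that includes this patch); the implicit conversion from `Seq[Attribute]` applies wherever an `AttributeSeq` is expected:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()  // each gets a fresh ExprId
val b = AttributeReference("b", IntegerType)()

val schema: AttributeSeq = Seq(a, b)   // implicit wrap of the Seq[Attribute]
assert(schema.indexOf(b.exprId) == 1)  // hash-map lookup, no linear scan
assert(schema(1) == b)                 // ordinal access backed by the cached array
assert(schema.indexOf(NamedExpression.newExprId) == -1)  // no match yields -1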
@@ -296,7 +296,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   /**
    * All the attributes that are used for this plan.
    */
-  lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output)
+  lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
Contributor Author commented:

@ericl and I found another layer of polynomial looping: in QueryPlan.cleanArgs we take every expression in the query plan and bind its references against allAttributes, which can be huge. If we turn this into an AttributeSeq once and build the map inside that wrapper, then we amortize that cost and remove this expensive loop.

Contributor Author commented:

We should probably construct the AttributeSeq outside of the loop in the various projection operators, too, although that doesn't appear to be as serious a bottleneck yet.
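
A hedged sketch of that suggested follow-up, with `child` and `projectList` standing in for the fields of a typical projection operator:

// Wrap once, outside the loop: the ExprId -> ordinal map is then built a
// single time and shared by every bindReference call below.
val boundInput: AttributeSeq = child.output
val boundProjectList = projectList.map { expr =>
  BindReferences.bindReference(expr, boundInput)
}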


   private def cleanExpression(e: Expression): Expression = e match {
     case a: Alias =>
@@ -49,7 +49,7 @@ case class HashAggregateExec(

   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
-  override lazy val allAttributes: Seq[Attribute] =
+  override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
       aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
 
@@ -310,7 +310,7 @@ private[sql] case class InMemoryTableScanExec(
     // within the map Partitions closure.
     val schema = relation.partitionStatistics.schema
     val schemaIndex = schema.zipWithIndex
-    val relOutput = relation.output
+    val relOutput: AttributeSeq = relation.output
     val buffers = relation.cachedColumnBuffers
 
     buffers.mapPartitionsInternal { cachedBatchIterator =>
@@ -321,7 +321,7 @@
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
         attributes.map { a =>
-          relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType
+          relOutput.indexOf(a.exprId) -> a.dataType
         }.unzip
 
       // Do partition batch pruning if enabled
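A note on the hoisting above (my reading; the PR does not spell this out): ascribing `relOutput: AttributeSeq` outside the `mapPartitionsInternal` closure triggers the implicit conversion exactly once, and because `AttributeSeq` is `Serializable` with `@transient lazy val` internals, each deserialized copy of the closure rebuilds the array and hash map at most once before serving constant-time `indexOf` lookups for all of the requested columns.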