Commit dd9d932

Avoid creating new Iterators
1 parent 589ea26 commit dd9d932

22 files changed: +224 -189 lines changed
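
Summary of the change: `SparkPlan.execute()` previously wrapped the output of `doExecute()` in an extra `map` whenever `trackNumOfRowsEnabled` was set, allocating one more Iterator per partition purely to count rows. After this commit each operator declares its own metrics and increments them inside iterators it already creates (or, for leaf scans, inside its own `rdd.map`). A minimal standalone sketch of the two patterns in plain Scala, with `Counter` as an illustrative stand-in for `LongSQLMetric` (not part of the patch):

// Standalone sketch; Counter stands in for LongSQLMetric.
object RowCountingSketch {
  final class Counter {
    var value: Long = 0L
    def +=(n: Long): Unit = value += n
  }

  // Before: a generic wrapper applied in SparkPlan.execute() around whatever
  // iterator the operator produced, allocating one extra Iterator per
  // partition just to count rows.
  def countByWrapping[A](rows: Iterator[A], numRows: Counter): Iterator[A] =
    rows.map { row => numRows += 1; row }

  // After: an iterator the operator has to build anyway (a trivial transform
  // here, standing in for e.g. an aggregation iterator) receives the counter
  // as a constructor argument and bumps it inside its own next().
  final class TransformIterator(input: Iterator[String], numOutputRows: Counter)
    extends Iterator[String] {
    override def hasNext: Boolean = input.hasNext
    override def next(): String = {
      numOutputRows += 1
      input.next().toUpperCase
    }
  }
}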

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 10 additions & 2 deletions
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.metric.SQLMetrics
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.{Row, SQLContext}
@@ -99,9 +100,16 @@ private[sql] case class PhysicalRDD(
     rdd: RDD[InternalRow],
     extraInformation: String) extends LeafNode {

-  override protected[sql] val trackNumOfRowsEnabled = true
+  override private[sql] lazy val metrics = Map(
+    "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))

-  protected override def doExecute(): RDD[InternalRow] = rdd
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numRows = longMetric("numRows")
+    rdd.map { row =>
+      numRows += 1
+      row
+    }
+  }

   override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]")
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala

Lines changed: 10 additions & 2 deletions
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.metric.SQLMetrics


 /**
@@ -30,11 +31,18 @@ private[sql] case class LocalTableScan(
     output: Seq[Attribute],
     rows: Seq[InternalRow]) extends LeafNode {

-  override protected[sql] val trackNumOfRowsEnabled = true
+  override private[sql] lazy val metrics = Map(
+    "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))

   private lazy val rdd = sqlContext.sparkContext.parallelize(rows)

-  protected override def doExecute(): RDD[InternalRow] = rdd
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numRows = longMetric("numRows")
+    rdd.map { row =>
+      numRows += 1
+      row
+    }
+  }

   override def executeCollect(): Array[Row] = {
     val converter = CatalystTypeConverters.createToScalaConverter(schema)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 2 additions & 23 deletions
@@ -80,23 +80,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     super.makeCopy(newArgs)
   }

-  /**
-   * Whether track the number of rows output by this SparkPlan
-   */
-  protected[sql] def trackNumOfRowsEnabled: Boolean = false
-
-  private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] =
-    if (trackNumOfRowsEnabled) {
-      Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))
-    }
-    else {
-      Map.empty
-    }
-
   /**
    * Return all metrics containing metrics of this SparkPlan.
    */
-  private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics
+  private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty

   /**
    * Return a IntSQLMetric according to the name.
@@ -156,15 +143,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     }
     RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
       prepare()
-      if (trackNumOfRowsEnabled) {
-        val numRows = longMetric("numRows")
-        doExecute().map { row =>
-          numRows += 1
-          row
-        }
-      } else {
-        doExecute()
-      }
+      doExecute()
     }
   }


sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala

Lines changed: 5 additions & 10 deletions
@@ -70,11 +70,7 @@ case class SortBasedAggregate(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numInputRows = longMetric("numInputRows")
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitions { _iter =>
-      val iter = _iter.map { row =>
-        numInputRows += 1
-        row
-      }
+    child.execute().mapPartitions { iter =>
       // Because the constructor of an aggregation iterator will read at least the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext
@@ -95,17 +91,16 @@ case class SortBasedAggregate(
         newProjection _,
         child.output,
         iter,
-        outputsUnsafeRows)
+        outputsUnsafeRows,
+        numInputRows,
+        numOutputRows)
       if (!hasInput && groupingExpressions.isEmpty) {
         // There is no input and there is no grouping expressions.
         // We need to output a single row as the output.
         numOutputRows += 1
         Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
       } else {
-        outputIter.map { row =>
-          numOutputRows += 1
-          row
-        }
+        outputIter
       }
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala

Lines changed: 13 additions & 5 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
+import org.apache.spark.sql.metric.LongSQLMetric
 import org.apache.spark.unsafe.KVIterator

 /**
@@ -37,7 +38,9 @@ class SortBasedAggregationIterator(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    outputsUnsafeRows: Boolean)
+    outputsUnsafeRows: Boolean,
+    numInputRows: LongSQLMetric,
+    numOutputRows: LongSQLMetric)
   extends AggregationIterator(
     groupingKeyAttributes,
     valueAttributes,
@@ -103,6 +106,7 @@ class SortBasedAggregationIterator(
     // Get the grouping key.
     val groupingKey = inputKVIterator.getKey
     val currentRow = inputKVIterator.getValue
+    numInputRows += 1

     // Check if the current row belongs the current input row.
     if (currentGroupingKey == groupingKey) {
@@ -137,7 +141,7 @@ class SortBasedAggregationIterator(
       val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
       // Initialize buffer values for the next group.
       initializeBuffer(sortBasedAggregationBuffer)
-
+      numOutputRows += 1
       outputRow
     } else {
       // no more result
@@ -151,7 +155,7 @@ class SortBasedAggregationIterator(

       nextGroupingKey = inputKVIterator.getKey().copy()
       firstRowInNextGroup = inputKVIterator.getValue().copy()
-
+      numInputRows += 1
       sortedInputHasNewGroup = true
     } else {
       // This inputIter is empty.
@@ -181,7 +185,9 @@ object SortBasedAggregationIterator {
       newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
       inputAttributes: Seq[Attribute],
       inputIter: Iterator[InternalRow],
-      outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
+      outputsUnsafeRows: Boolean,
+      numInputRows: LongSQLMetric,
+      numOutputRows: LongSQLMetric): SortBasedAggregationIterator = {
     val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
       AggregationIterator.unsafeKVIterator(
         groupingExprs,
@@ -202,7 +208,9 @@ object SortBasedAggregationIterator {
       initialInputBufferOffset,
       resultExpressions,
       newMutableProjection,
-      outputsUnsafeRows)
+      outputsUnsafeRows,
+      numInputRows,
+      numOutputRows)
   }
   // scalastyle:on
 }
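
The pattern used here (simplified; the real SortBasedAggregationIterator also manages aggregation buffers and projections) can be illustrated with a small self-contained iterator: the two counters arrive as constructor parameters and are bumped at the points where the iterator already reads its input and emits a group, so no wrapping `map` iterator is needed. `Counter` is an illustrative stand-in for `LongSQLMetric`:

// Illustrative stand-in for LongSQLMetric (not the real class).
final class Counter {
  var value: Long = 0L
  def +=(n: Long): Unit = value += n
}

// Sums consecutive equal keys from a key-sorted input. Input rows are counted
// as they are read and output rows as each group is emitted, inside the
// iterator itself rather than through a wrapping `map`.
final class SumByKeyIterator(
    input: Iterator[(String, Long)],
    numInputRows: Counter,
    numOutputRows: Counter)
  extends Iterator[(String, Long)] {

  // First row of the group that the next call to next() will emit.
  private var pending: Option[(String, Long)] =
    if (input.hasNext) { numInputRows += 1; Some(input.next()) } else None

  override def hasNext: Boolean = pending.isDefined

  override def next(): (String, Long) = {
    val (key, firstValue) = pending.get
    var sum = firstValue
    pending = None
    // Consume rows until the key changes or the input is exhausted.
    while (pending.isEmpty && input.hasNext) {
      numInputRows += 1
      val kv = input.next()
      if (kv._1 == key) sum += kv._2 else pending = Some(kv)
    }
    numOutputRows += 1
    (key, sum)
  }
}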

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 5 additions & 10 deletions
@@ -68,11 +68,7 @@ case class TungstenAggregate(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numInputRows = longMetric("numInputRows")
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitions { _iter =>
-      val iter = _iter.map { row =>
-        numInputRows += 1
-        row
-      }
+    child.execute().mapPartitions { iter =>
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input iterator is empty,
@@ -89,16 +85,15 @@ case class TungstenAggregate(
         newMutableProjection,
         child.output,
         iter,
-        testFallbackStartsAt)
+        testFallbackStartsAt,
+        numInputRows,
+        numOutputRows)

       if (!hasInput && groupingExpressions.isEmpty) {
         numOutputRows += 1
         Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
       } else {
-        aggregationIterator.map { row =>
-          numOutputRows += 1
-          row
-        }
+        aggregationIterator
       }
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 9 additions & 2 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.metric.LongSQLMetric
 import org.apache.spark.sql.types.StructType

 /**
@@ -83,7 +84,9 @@ class TungstenAggregationIterator(
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
     inputIter: Iterator[InternalRow],
-    testFallbackStartsAt: Option[Int])
+    testFallbackStartsAt: Option[Int],
+    numInputRows: LongSQLMetric,
+    numOutputRows: LongSQLMetric)
   extends Iterator[UnsafeRow] with Logging {

   ///////////////////////////////////////////////////////////////////////////
@@ -352,6 +355,7 @@ class TungstenAggregationIterator(
   private def processInputs(): Unit = {
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
+      numInputRows += 1
       val groupingKey = groupProjection.apply(newInput)
       val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
       if (buffer == null) {
@@ -371,6 +375,7 @@ class TungstenAggregationIterator(
     var i = 0
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
+      numInputRows += 1
       val groupingKey = groupProjection.apply(newInput)
       val buffer: UnsafeRow = if (i < fallbackStartsAt) {
         hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
@@ -439,6 +444,7 @@ class TungstenAggregationIterator(
     // Process the rest of input rows.
     while (inputIter.hasNext) {
       val newInput = inputIter.next()
+      numInputRows += 1
       val groupingKey = groupProjection.apply(newInput)
       buffer.copyFrom(initialAggregationBuffer)
       processRow(buffer, newInput)
@@ -462,6 +468,7 @@ class TungstenAggregationIterator(
     // Insert the rest of input rows.
     while (inputIter.hasNext) {
       val newInput = inputIter.next()
+      numInputRows += 1
       val groupingKey = groupProjection.apply(newInput)
       bufferExtractor(newInput)
       externalSorter.insertKV(groupingKey, buffer)
@@ -657,7 +664,7 @@ class TungstenAggregationIterator(
         TaskContext.get().internalMetricsToAccumulators(
           InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
       }
-
+      numOutputRows += 1
       res
     } else {
       // no more result

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

Lines changed: 6 additions & 8 deletions
@@ -86,7 +86,11 @@ case class BroadcastHashJoin(
         numBuildRows += 1
         row.copy()
       }.collect()
-      val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size)
+      // The following line doesn't run in a job so we cannot track the metric value. However, we
+      // have already tracked it in the above lines. So here we can use
+      // `SQLMetrics.nullLongMetric` to ignore it.
+      val hashed = HashedRelation(
+        input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
       sparkContext.broadcast(hashed)
     }
   }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
@@ -113,13 +117,7 @@ case class BroadcastHashJoin(
           InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
         case _ =>
       }
-      hashJoin(streamedIter.map { row =>
-        numStreamedRows += 1
-        row
-      }, hashedRelation).map { row =>
-        numOutputRows += 1
-        row
-      }
+      hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
     }
   }
 }
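
`SQLMetrics.nullLongMetric`, referenced in the comment above, is not shown in this diff; the idea is a metric whose updates are discarded, so a call site that requires a metric argument (here the driver-side `HashedRelation` construction, whose rows were already counted when they were collected) can be handed one that records nothing. A rough standalone sketch of that idea with illustrative names, not the actual SQLMetrics classes:

// Illustrative stand-ins; not the actual SQLMetrics classes.
class Counter {
  var value: Long = 0L
  def +=(n: Long): Unit = value += n
}

// A counter whose updates are no-ops, for call sites where the rows have
// already been counted elsewhere and must not be counted twice.
object NullCounter extends Counter {
  override def +=(n: Long): Unit = ()
}

object HashedRelationSketch {
  // Builds a lookup structure from already-collected rows, incrementing the
  // supplied counter once per row it consumes.
  def buildLookup(rows: Iterator[String], numRows: Counter): Set[String] = {
    val builder = Set.newBuilder[String]
    while (rows.hasNext) {
      numRows += 1
      builder += rows.next()
    }
    builder.result()
  }

  // The rows were already counted when they were collected on the driver, so
  // the lookup builder gets the no-op counter.
  def example(): Set[String] = buildLookup(Iterator("a", "b", "c"), NullCounter)
}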

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala

Lines changed: 10 additions & 11 deletions
@@ -89,7 +89,11 @@ case class BroadcastHashOuterJoin(
         numBuildRows += 1
         row.copy()
       }.collect()
-      val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
+      // The following line doesn't run in a job so we cannot track the metric value. However, we
+      // have already tracked it in the above lines. So here we can use
+      // `SQLMetrics.nullLongMetric` to ignore it.
+      val hashed = HashedRelation(
+        input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size)
       sparkContext.broadcast(hashed)
     }
   }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
@@ -111,11 +115,7 @@ case class BroadcastHashOuterJoin(

     val broadcastRelation = Await.result(broadcastFuture, timeout)

-    streamedPlan.execute().mapPartitions { _streamedIter =>
-      val streamedIter = _streamedIter.map { row =>
-        numStreamedRows += 1
-        row
-      }
+    streamedPlan.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow()
       val hashTable = broadcastRelation.value
       val keyGenerator = streamedKeyGenerator
@@ -131,25 +131,24 @@ case class BroadcastHashOuterJoin(
       joinType match {
         case LeftOuter =>
           streamedIter.flatMap(currentRow => {
+            numStreamedRows += 1
             val rowKey = keyGenerator(currentRow)
             joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj)
+            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
           })

         case RightOuter =>
           streamedIter.flatMap(currentRow => {
+            numStreamedRows += 1
             val rowKey = keyGenerator(currentRow)
             joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj)
+            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
           })

         case x =>
           throw new IllegalArgumentException(
             s"BroadcastHashOuterJoin should not take $x as the JoinType")
       }
-    }.map { row =>
-      numOutputRows += 1
-      row
     }
   }
 }
