Skip to content

Commit 4bef25a

Browse files
committed
Add metrics for SortMergeOuterJoin
1 parent 95ccfc6 commit 4bef25a

File tree

2 files changed

+62
-8
lines changed

2 files changed

+62
-8
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
2525
import org.apache.spark.sql.catalyst.plans.physical._
2626
import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
27+
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
2728

2829
/**
2930
* :: DeveloperApi ::
@@ -40,6 +41,11 @@ case class SortMergeOuterJoin(
4041
left: SparkPlan,
4142
right: SparkPlan) extends BinaryNode {
4243

44+
override private[sql] lazy val metrics = Map(
45+
"numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
46+
"numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
47+
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
48+
4349
override def output: Seq[Attribute] = {
4450
joinType match {
4551
case LeftOuter =>
@@ -108,6 +114,10 @@ case class SortMergeOuterJoin(
108114
}
109115

110116
override def doExecute(): RDD[InternalRow] = {
117+
val numLeftRows = longMetric("numLeftRows")
118+
val numRightRows = longMetric("numRightRows")
119+
val numOutputRows = longMetric("numOutputRows")
120+
111121
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
112122
// An ordering that can be used to compare keys from both sides.
113123
val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
@@ -133,21 +143,27 @@ case class SortMergeOuterJoin(
133143
bufferedKeyGenerator = createRightKeyGenerator(),
134144
keyOrdering,
135145
streamedIter = RowIterator.fromScala(leftIter),
136-
bufferedIter = RowIterator.fromScala(rightIter)
146+
numLeftRows,
147+
bufferedIter = RowIterator.fromScala(rightIter),
148+
numRightRows
137149
)
138150
val rightNullRow = new GenericInternalRow(right.output.length)
139-
new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala
151+
new LeftOuterIterator(
152+
smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
140153

141154
case RightOuter =>
142155
val smjScanner = new SortMergeJoinScanner(
143156
streamedKeyGenerator = createRightKeyGenerator(),
144157
bufferedKeyGenerator = createLeftKeyGenerator(),
145158
keyOrdering,
146159
streamedIter = RowIterator.fromScala(rightIter),
147-
bufferedIter = RowIterator.fromScala(leftIter)
160+
numRightRows,
161+
bufferedIter = RowIterator.fromScala(leftIter),
162+
numLeftRows
148163
)
149164
val leftNullRow = new GenericInternalRow(left.output.length)
150-
new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala
165+
new RightOuterIterator(
166+
smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
151167

152168
case x =>
153169
throw new IllegalArgumentException(
@@ -162,7 +178,8 @@ private class LeftOuterIterator(
162178
smjScanner: SortMergeJoinScanner,
163179
rightNullRow: InternalRow,
164180
boundCondition: InternalRow => Boolean,
165-
resultProj: InternalRow => InternalRow
181+
resultProj: InternalRow => InternalRow,
182+
numRows: LongSQLMetric
166183
) extends RowIterator {
167184
private[this] val joinedRow: JoinedRow = new JoinedRow()
168185
private[this] var rightIdx: Int = 0
@@ -198,7 +215,9 @@ private class LeftOuterIterator(
198215
}
199216

200217
override def advanceNext(): Boolean = {
201-
advanceRightUntilBoundConditionSatisfied() || advanceLeft()
218+
val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft()
219+
if (r) numRows += 1
220+
r
202221
}
203222

204223
override def getRow: InternalRow = resultProj(joinedRow)
@@ -208,7 +227,8 @@ private class RightOuterIterator(
208227
smjScanner: SortMergeJoinScanner,
209228
leftNullRow: InternalRow,
210229
boundCondition: InternalRow => Boolean,
211-
resultProj: InternalRow => InternalRow
230+
resultProj: InternalRow => InternalRow,
231+
numRows: LongSQLMetric
212232
) extends RowIterator {
213233
private[this] val joinedRow: JoinedRow = new JoinedRow()
214234
private[this] var leftIdx: Int = 0
@@ -244,7 +264,9 @@ private class RightOuterIterator(
244264
}
245265

246266
override def advanceNext(): Boolean = {
247-
advanceLeftUntilBoundConditionSatisfied() || advanceRight()
267+
val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight()
268+
if (r) numRows += 1
269+
r
248270
}
249271

250272
override def getRow: InternalRow = resultProj(joinedRow)

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
264264
}
265265
}
266266

267+
test("SortMergeOuterJoin metrics") {
268+
// Because SortMergeOuterJoin may skip different rows if the number of partitions is different,
269+
// this test should use a deterministic number of partitions.
270+
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
271+
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
272+
testDataForJoin.registerTempTable("testDataForJoin")
273+
withTempTable("testDataForJoin") {
274+
// Assume the execution plan is
275+
// ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
276+
val df = sqlContext.sql(
277+
"SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a")
278+
testSparkPlanMetrics(df, 1, Map(
279+
1L -> ("SortMergeOuterJoin", Map(
280+
// testData2 is the left side (6 rows read), testDataForJoin the right side (2 rows read)
281+
"number of left rows" -> 6L,
282+
"number of right rows" -> 2L,
283+
"number of output rows" -> 8L)))
284+
)
285+
286+
val df2 = sqlContext.sql(
287+
"SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a")
288+
testSparkPlanMetrics(df2, 1, Map(
289+
1L -> ("SortMergeOuterJoin", Map(
290+
// testDataForJoin is the left side (2 rows read), testData2 the right side (6 rows read)
291+
"number of left rows" -> 2L,
292+
"number of right rows" -> 6L,
293+
"number of output rows" -> 8L)))
294+
)
295+
}
296+
}
297+
}
298+
267299
test("BroadcastHashJoin metrics") {
268300
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
269301
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")

0 commit comments

Comments
 (0)