@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
2424import org .apache .spark .sql .catalyst .plans .{JoinType , LeftOuter , RightOuter }
2525import org .apache .spark .sql .catalyst .plans .physical ._
2626import org .apache .spark .sql .execution .{BinaryNode , RowIterator , SparkPlan }
27+ import org .apache .spark .sql .execution .metric .{LongSQLMetric , SQLMetrics }
2728
2829/**
2930 * :: DeveloperApi ::
@@ -40,6 +41,11 @@ case class SortMergeOuterJoin(
4041 left : SparkPlan ,
4142 right : SparkPlan ) extends BinaryNode {
4243
44+ override private [sql] lazy val metrics = Map (
45+ " numLeftRows" -> SQLMetrics .createLongMetric(sparkContext, " number of left rows" ),
46+ " numRightRows" -> SQLMetrics .createLongMetric(sparkContext, " number of right rows" ),
47+ " numOutputRows" -> SQLMetrics .createLongMetric(sparkContext, " number of output rows" ))
48+
4349 override def output : Seq [Attribute ] = {
4450 joinType match {
4551 case LeftOuter =>
@@ -108,6 +114,10 @@ case class SortMergeOuterJoin(
108114 }
109115
110116 override def doExecute (): RDD [InternalRow ] = {
117+ val numLeftRows = longMetric(" numLeftRows" )
118+ val numRightRows = longMetric(" numRightRows" )
119+ val numOutputRows = longMetric(" numOutputRows" )
120+
111121 left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
112122 // An ordering that can be used to compare keys from both sides.
113123 val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
@@ -133,21 +143,27 @@ case class SortMergeOuterJoin(
133143 bufferedKeyGenerator = createRightKeyGenerator(),
134144 keyOrdering,
135145 streamedIter = RowIterator .fromScala(leftIter),
136- bufferedIter = RowIterator .fromScala(rightIter)
146+ numLeftRows,
147+ bufferedIter = RowIterator .fromScala(rightIter),
148+ numRightRows
137149 )
138150 val rightNullRow = new GenericInternalRow (right.output.length)
139- new LeftOuterIterator (smjScanner, rightNullRow, boundCondition, resultProj).toScala
151+ new LeftOuterIterator (
152+ smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
140153
141154 case RightOuter =>
142155 val smjScanner = new SortMergeJoinScanner (
143156 streamedKeyGenerator = createRightKeyGenerator(),
144157 bufferedKeyGenerator = createLeftKeyGenerator(),
145158 keyOrdering,
146159 streamedIter = RowIterator .fromScala(rightIter),
147- bufferedIter = RowIterator .fromScala(leftIter)
160+ numRightRows,
161+ bufferedIter = RowIterator .fromScala(leftIter),
162+ numLeftRows
148163 )
149164 val leftNullRow = new GenericInternalRow (left.output.length)
150- new RightOuterIterator (smjScanner, leftNullRow, boundCondition, resultProj).toScala
165+ new RightOuterIterator (
166+ smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
151167
152168 case x =>
153169 throw new IllegalArgumentException (
@@ -162,7 +178,8 @@ private class LeftOuterIterator(
162178 smjScanner : SortMergeJoinScanner ,
163179 rightNullRow : InternalRow ,
164180 boundCondition : InternalRow => Boolean ,
165- resultProj : InternalRow => InternalRow
181+ resultProj : InternalRow => InternalRow ,
182+ numRows : LongSQLMetric
166183 ) extends RowIterator {
167184 private [this ] val joinedRow : JoinedRow = new JoinedRow ()
168185 private [this ] var rightIdx : Int = 0
@@ -198,7 +215,9 @@ private class LeftOuterIterator(
198215 }
199216
200217 override def advanceNext (): Boolean = {
201- advanceRightUntilBoundConditionSatisfied() || advanceLeft()
218+ val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft()
219+ if (r) numRows += 1
220+ r
202221 }
203222
204223 override def getRow : InternalRow = resultProj(joinedRow)
@@ -208,7 +227,8 @@ private class RightOuterIterator(
208227 smjScanner : SortMergeJoinScanner ,
209228 leftNullRow : InternalRow ,
210229 boundCondition : InternalRow => Boolean ,
211- resultProj : InternalRow => InternalRow
230+ resultProj : InternalRow => InternalRow ,
231+ numRows : LongSQLMetric
212232 ) extends RowIterator {
213233 private [this ] val joinedRow : JoinedRow = new JoinedRow ()
214234 private [this ] var leftIdx : Int = 0
@@ -244,7 +264,9 @@ private class RightOuterIterator(
244264 }
245265
246266 override def advanceNext (): Boolean = {
247- advanceLeftUntilBoundConditionSatisfied() || advanceRight()
267+ val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight()
268+ if (r) numRows += 1
269+ r
248270 }
249271
250272 override def getRow : InternalRow = resultProj(joinedRow)
0 commit comments