
Commit d650075

JkSelf authored and dongjoon-hyun committed

[SPARK-26316][SPARK-21052][BRANCH-2.4] Revert hash join metrics that cause performance degradation

## What changes were proposed in this pull request?

Revert SPARK-21052 in Spark 2.4 because of the performance regression discussed in [PR23269](#23269).

## How was this patch tested?

N/A

Closes #23318 from JkSelf/branch-2.4-revert21052.

Authored-by: jiake <ke.a.jia@intel.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>

1 parent 869bfc9 · commit d650075

File tree: 5 files changed, +6 −159 lines
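
Background for the change: SPARK-21052 computed an "avg hash probe" metric by bumping counters on every key lookup and every probe inside HashedRelation's lookup loops and publishing the ratio through a task-completion listener (the exact lines appear in the diffs below). The standalone Scala sketch below is illustrative only — SimpleLongMap and its members are hypothetical, not Spark classes — and shows where that kind of counter sits in an open-addressing lookup loop, i.e. the per-probe bookkeeping this revert takes back out of the hot path.

```scala
// Simplified, hypothetical sketch (not the actual Spark classes) of the kind of
// probe bookkeeping SPARK-21052 added and this commit removes.
object ProbeCounterSketch {

  // Open-addressing map over a flat array: even slots hold keys, odd slots hold values, 0 = empty.
  final class SimpleLongMap(capacity: Int) {
    require((capacity & (capacity - 1)) == 0, "capacity must be a power of two")
    private val array = new Array[Long](capacity * 2)
    private val mask = capacity - 1

    // Counters of the kind the revert deletes: every probe pays for extra field updates.
    private var numKeyLookups = 0L
    private var numProbes = 0L

    private def firstSlot(key: Long): Int = (key.toInt & mask) * 2
    private def nextSlot(pos: Int): Int = (pos + 2) & (array.length - 1)

    def put(key: Long, value: Long): Unit = {
      var pos = firstSlot(key)
      while (array(pos + 1) != 0 && array(pos) != key) pos = nextSlot(pos)
      array(pos) = key
      array(pos + 1) = value
    }

    // Returns 0 when the key is absent (values are assumed non-zero).
    def lookup(key: Long): Long = {
      numKeyLookups += 1                 // removed by the revert
      numProbes += 1                     // removed by the revert
      var pos = firstSlot(key)
      while (array(pos + 1) != 0) {
        if (array(pos) == key) return array(pos + 1)
        pos = nextSlot(pos)
        numProbes += 1                   // removed by the revert: cost paid on every extra probe
      }
      0L
    }

    def averageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups
  }

  def main(args: Array[String]): Unit = {
    val map = new SimpleLongMap(1 << 16)
    (1L to 10000L).foreach(k => map.put(k, k * 10))
    (1L to 10000L).foreach(map.lookup)
    println(f"avg probes per lookup: ${map.averageProbesPerLookup}%.2f")
  }
}
```

After the revert, the lookup loop body is just the key comparison and the slot advance.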

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala

Lines changed: 2 additions & 25 deletions
@@ -48,8 +48,7 @@ case class BroadcastHashJoinExec(
   extends BinaryExecNode with HashJoin with CodegenSupport {

   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-    "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

   override def requiredChildDistribution: Seq[Distribution] = {
     val mode = HashedRelationBroadcastMode(buildKeys)

@@ -63,13 +62,12 @@ case class BroadcastHashJoinExec(

   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    val avgHashProbe = longMetric("avgHashProbe")

     val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
     streamedPlan.execute().mapPartitions { streamedIter =>
       val hashed = broadcastRelation.value.asReadOnlyCopy()
       TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
-      join(streamedIter, hashed, numOutputRows, avgHashProbe)
+      join(streamedIter, hashed, numOutputRows)
     }
   }

@@ -111,23 +109,6 @@ case class BroadcastHashJoinExec(
     }
   }

-  /**
-   * Returns the codes used to add a task completion listener to update avg hash probe
-   * at the end of the task.
-   */
-  private def genTaskListener(avgHashProbe: String, relationTerm: String): String = {
-    val listenerClass = classOf[TaskCompletionListener].getName
-    val taskContextClass = classOf[TaskContext].getName
-    s"""
-       | $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() {
-       |   @Override
-       |   public void onTaskCompletion($taskContextClass context) {
-       |     $avgHashProbe.set($relationTerm.getAverageProbesPerLookup());
-       |   }
-       | });
-     """.stripMargin
-  }
-
   /**
    * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
    */

@@ -137,15 +118,11 @@ case class BroadcastHashJoinExec(
     val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
     val clsName = broadcastRelation.value.getClass.getName

-    // At the end of the task, we update the avg hash probe.
-    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
-
     // Inline mutable state since not many join operations in a task
     val relationTerm = ctx.addMutableState(clsName, "relation",
       v => s"""
          | $v = (($clsName) $broadcast.value()).asReadOnlyCopy();
          | incPeakExecutionMemory($v.estimatedSize());
-         | ${genTaskListener(avgHashProbe, v)}
        """.stripMargin, forceInline = true)
     (broadcastRelation, relationTerm)
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 1 addition & 6 deletions
@@ -194,8 +194,7 @@ trait HashJoin {
   protected def join(
       streamedIter: Iterator[InternalRow],
       hashed: HashedRelation,
-      numOutputRows: SQLMetric,
-      avgHashProbe: SQLMetric): Iterator[InternalRow] = {
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {

     val joinedIter = joinType match {
       case _: InnerLike =>

@@ -213,10 +212,6 @@ trait HashJoin {
         s"BroadcastHashJoin should not take $x as the JoinType")
     }

-    // At the end of the task, we update the avg hash probe.
-    TaskContext.get().addTaskCompletionListener[Unit](_ =>
-      avgHashProbe.set(hashed.getAverageProbesPerLookup))
-
     val resultProj = createResultProjection
     joinedIter.map { r =>
       numOutputRows += 1
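
After this change, join() carries only numOutputRows, and the sole metric work left on the row path is the per-row increment inside the retained joinedIter.map. A minimal standalone sketch of that surviving pattern — LongMetric here is a hypothetical stand-in, not Spark's SQLMetric:

```scala
// Minimal sketch of the per-row counting that remains after the revert;
// LongMetric is a hypothetical stand-in for org.apache.spark.sql.execution.metric.SQLMetric.
object RowCountMetricSketch {
  final class LongMetric(val name: String) {
    private var value = 0L
    def +=(v: Long): Unit = value += v
    def get: Long = value
  }

  // Wrap the joined iterator so each produced row bumps a single counter.
  def countRows[A](joinedIter: Iterator[A], numOutputRows: LongMetric): Iterator[A] =
    joinedIter.map { row =>
      numOutputRows += 1
      row
    }

  def main(args: Array[String]): Unit = {
    val numOutputRows = new LongMetric("number of output rows")
    countRows(Iterator.tabulate(5)(i => (i, s"row$i")), numOutputRows).foreach(_ => ())
    println(s"${numOutputRows.name}: ${numOutputRows.get}")  // 5
  }
}
```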

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 0 additions & 31 deletions
@@ -81,10 +81,6 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
    */
   def close(): Unit

-  /**
-   * Returns the average number of probes per key lookup.
-   */
-  def getAverageProbesPerLookup: Double
 }

 private[execution] object HashedRelation {

@@ -281,7 +277,6 @@ private[joins] class UnsafeHashedRelation(
     read(() => in.readInt(), () => in.readLong(), in.readBytes)
   }

-  override def getAverageProbesPerLookup: Double = binaryMap.getAverageProbesPerLookup
 }

 private[joins] object UnsafeHashedRelation {

@@ -395,10 +390,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   // The number of unique keys.
   private var numKeys = 0L

-  // Tracking average number of probes per key lookup.
-  private var numKeyLookups = 0L
-  private var numProbes = 0L
-
   // needed by serializer
   def this() = {
     this(

@@ -483,8 +474,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
     if (isDense) {
-      numKeyLookups += 1
-      numProbes += 1
       if (key >= minKey && key <= maxKey) {
         val value = array((key - minKey).toInt)
         if (value > 0) {

@@ -493,14 +482,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       }
     } else {
       var pos = firstSlot(key)
-      numKeyLookups += 1
-      numProbes += 1
       while (array(pos + 1) != 0) {
         if (array(pos) == key) {
           return getRow(array(pos + 1), resultRow)
         }
         pos = nextSlot(pos)
-        numProbes += 1
       }
     }
     null

@@ -528,8 +514,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     if (isDense) {
-      numKeyLookups += 1
-      numProbes += 1
       if (key >= minKey && key <= maxKey) {
         val value = array((key - minKey).toInt)
         if (value > 0) {

@@ -538,14 +522,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       }
     } else {
       var pos = firstSlot(key)
-      numKeyLookups += 1
-      numProbes += 1
       while (array(pos + 1) != 0) {
         if (array(pos) == key) {
           return valueIter(array(pos + 1), resultRow)
         }
         pos = nextSlot(pos)
-        numProbes += 1
       }
     }
     null

@@ -585,11 +566,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   private def updateIndex(key: Long, address: Long): Unit = {
     var pos = firstSlot(key)
     assert(numKeys < array.length / 2)
-    numKeyLookups += 1
-    numProbes += 1
     while (array(pos) != key && array(pos + 1) != 0) {
       pos = nextSlot(pos)
-      numProbes += 1
     }
     if (array(pos + 1) == 0) {
       // this is the first value for this key, put the address in array.

@@ -721,8 +699,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     writeLong(maxKey)
     writeLong(numKeys)
     writeLong(numValues)
-    writeLong(numKeyLookups)
-    writeLong(numProbes)

     writeLong(array.length)
     writeLongArray(writeBuffer, array, array.length)

@@ -764,8 +740,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     maxKey = readLong()
     numKeys = readLong()
     numValues = readLong()
-    numKeyLookups = readLong()
-    numProbes = readLong()

     val length = readLong().toInt
     mask = length - 2

@@ -784,10 +758,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     read(() => in.readBoolean(), () => in.readLong(), in.readBytes)
   }

-  /**
-   * Returns the average number of probes per key lookup.
-   */
-  def getAverageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups
 }

 private[joins] class LongHashedRelation(

@@ -840,7 +810,6 @@ private[joins] class LongHashedRelation(
     map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
   }

-  override def getAverageProbesPerLookup: Double = map.getAverageProbesPerLookup
 }

 /**
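
The write()/read() hunks above also explain why the counters had to leave serialization in the same commit: LongToUnsafeRowMap's reader consumes fields in exactly the order the writer emits them, so dropping writeLong(numKeyLookups)/writeLong(numProbes) requires dropping the matching readLong() calls. A small self-contained sketch of that symmetry; MapHeader and its fields are illustrative, not the real on-the-wire layout:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Illustrative sketch of writer/reader symmetry; MapHeader is hypothetical.
object WriteReadSymmetrySketch {
  final case class MapHeader(minKey: Long, maxKey: Long, numKeys: Long, numValues: Long)

  def write(h: MapHeader, out: DataOutputStream): Unit = {
    out.writeLong(h.minKey)
    out.writeLong(h.maxKey)
    out.writeLong(h.numKeys)
    out.writeLong(h.numValues)
    // post-revert: no writeLong(numKeyLookups) / writeLong(numProbes)
  }

  def read(in: DataInputStream): MapHeader =
    // Arguments are evaluated left to right, mirroring the write order above.
    // post-revert: no numKeyLookups = readLong() / numProbes = readLong()
    MapHeader(in.readLong(), in.readLong(), in.readLong(), in.readLong())

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    write(MapHeader(1L, 100L, 64L, 128L), new DataOutputStream(buf))
    println(read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray))))
  }
}
```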

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 2 additions & 4 deletions
@@ -42,8 +42,7 @@ case class ShuffledHashJoinExec(
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
-    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"),
-    "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
+    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))

   override def requiredChildDistribution: Seq[Distribution] =
     HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil

@@ -63,10 +62,9 @@ case class ShuffledHashJoinExec(

   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    val avgHashProbe = longMetric("avgHashProbe")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows, avgHashProbe)
+      join(streamIter, hashed, numOutputRows)
     }
   }
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 1 addition & 93 deletions
@@ -231,50 +231,6 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     )
   }

-  test("BroadcastHashJoin metrics: track avg probe") {
-    // The executed plan looks like:
-    // Project [a#210, b#211, b#221]
-    // +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight
-    //    :- Project [_1#207 AS a#210, _2#208 AS b#211]
-    //    :  +- Filter isnotnull(_1#207)
-    //    :     +- LocalTableScan [_1#207, _2#208]
-    //    +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true]))
-    //       +- Project [_1#217 AS a#220, _2#218 AS b#221]
-    //          +- Filter isnotnull(_1#217)
-    //             +- LocalTableScan [_1#217, _2#218]
-    //
-    // Assume the execution plan with node id is
-    // WholeStageCodegen disabled:
-    // Project(nodeId = 0)
-    //   BroadcastHashJoin(nodeId = 1)
-    //     ...(ignored)
-    //
-    // WholeStageCodegen enabled:
-    // WholeStageCodegen(nodeId = 0)
-    //   Project(nodeId = 1)
-    //     BroadcastHashJoin(nodeId = 2)
-    //       Project(nodeId = 3)
-    //         Filter(nodeId = 4)
-    //           ...(ignored)
-    Seq(true, false).foreach { enableWholeStage =>
-      val df1 = generateRandomBytesDF()
-      val df2 = generateRandomBytesDF()
-      val df = df1.join(broadcast(df2), "a")
-      val nodeIds = if (enableWholeStage) {
-        Set(2L)
-      } else {
-        Set(1L)
-      }
-      val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get
-      nodeIds.foreach { nodeId =>
-        val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
-        probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
-          assert(probe.toDouble > 1.0)
-        }
-      }
-    }
-  }
-
   test("ShuffledHashJoin metrics") {
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40",
       "spark.sql.shuffle.partitions" -> "2",

@@ -287,59 +243,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
       val metrics = getSparkPlanMetrics(df, 1, Set(1L))
       testSparkPlanMetrics(df, 1, Map(
         1L -> (("ShuffledHashJoin", Map(
-          "number of output rows" -> 2L,
-          "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))))
+          "number of output rows" -> 2L))))
       )
     }
   }

-  test("ShuffledHashJoin metrics: track avg probe") {
-    // The executed plan looks like:
-    // Project [a#308, b#309, b#319]
-    // +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight
-    //    :- Exchange hashpartitioning(a#308, 2)
-    //    :  +- Project [_1#305 AS a#308, _2#306 AS b#309]
-    //    :     +- Filter isnotnull(_1#305)
-    //    :        +- LocalTableScan [_1#305, _2#306]
-    //    +- Exchange hashpartitioning(a#318, 2)
-    //       +- Project [_1#315 AS a#318, _2#316 AS b#319]
-    //          +- Filter isnotnull(_1#315)
-    //             +- LocalTableScan [_1#315, _2#316]
-    //
-    // Assume the execution plan with node id is
-    // WholeStageCodegen disabled:
-    // Project(nodeId = 0)
-    //   ShuffledHashJoin(nodeId = 1)
-    //     ...(ignored)
-    //
-    // WholeStageCodegen enabled:
-    // WholeStageCodegen(nodeId = 0)
-    //   Project(nodeId = 1)
-    //     ShuffledHashJoin(nodeId = 2)
-    //       ...(ignored)
-    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000",
-      "spark.sql.shuffle.partitions" -> "2",
-      "spark.sql.join.preferSortMergeJoin" -> "false") {
-      Seq(true, false).foreach { enableWholeStage =>
-        val df1 = generateRandomBytesDF(65535 * 5)
-        val df2 = generateRandomBytesDF(65535)
-        val df = df1.join(df2, "a")
-        val nodeIds = if (enableWholeStage) {
-          Set(2L)
-        } else {
-          Set(1L)
-        }
-        val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
-        nodeIds.foreach { nodeId =>
-          val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
-          probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
-            assert(probe.toDouble > 1.0)
-          }
-        }
-      }
-    }
-  }
-
   test("BroadcastHashJoin(outer) metrics") {
     val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
     val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
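
With the two "track avg probe" tests deleted, the suite still asserts the join's row-count metric, as the retained ShuffledHashJoin assertion above shows. As a hedged sketch (not part of this patch), one way to observe the post-revert behavior in a local application is to run a broadcast join and inspect the executed plan's metrics; after the revert the join node should still expose numOutputRows but no avgHashProbe:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

// Hedged sketch for manual verification against a build containing this revert.
object InspectJoinMetricsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("metrics-sketch").getOrCreate()
    import spark.implicits._

    val df1 = Seq((1, "a"), (2, "b")).toDF("key", "v1")
    val df2 = Seq((1, "x"), (3, "y")).toDF("key", "v2")
    val joined = df1.join(broadcast(df2), "key")
    joined.collect()  // run the query so the metrics get populated

    // Find the broadcast hash join node in the executed plan.
    val bhj = joined.queryExecution.executedPlan.collect {
      case p if p.nodeName.contains("BroadcastHashJoin") => p
    }.head

    println(bhj.metrics("numOutputRows").value)    // expected: 1 (one matching key)
    println(bhj.metrics.contains("avgHashProbe"))  // expected: false after this revert
    spark.stop()
  }
}
```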
