[SPARK-16139][TEST] Add logging functionality for leaked threads in tests #19893

Closed · wants to merge 21 commits
34 changes: 34 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -27,19 +27,53 @@ import org.apache.spark.util.AccumulatorContext

/**
* Base abstract class for all unit tests in Spark for handling common functionality.
*
* Thread audit happens here automatically when a new test suite is created.
* The only prerequisite is that the test class must extend [[SparkFunSuite]].
*
* It is possible to override the default thread audit behavior by setting enableAutoThreadAudit
* to false and manually calling the audit methods, if desired. For example:
*
* class MyTestSuite extends SparkFunSuite {
*
* override val enableAutoThreadAudit = false
*
* protected override def beforeAll(): Unit = {
* doThreadPreAudit()
* super.beforeAll()
* }
*
* protected override def afterAll(): Unit = {
* super.afterAll()
* doThreadPostAudit()
* }
* }
*/
abstract class SparkFunSuite
extends FunSuite
with BeforeAndAfterAll
with ThreadAudit
with Logging {
// scalastyle:on

protected val enableAutoThreadAudit = true

protected override def beforeAll(): Unit = {
if (enableAutoThreadAudit) {
doThreadPreAudit()
}
super.beforeAll()
}

protected override def afterAll(): Unit = {
try {
// Avoid leaking map entries in tests that use accumulators without SparkContext
AccumulatorContext.clear()
} finally {
super.afterAll()
if (enableAutoThreadAudit) {
doThreadPostAudit()
}
}
}

99 changes: 99 additions & 0 deletions core/src/test/scala/org/apache/spark/ThreadAudit.scala
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging

/**
* Thread audit for test suites.
*/
trait ThreadAudit extends Logging {

val threadWhiteList = Set(
/**
* Netty related internal threads.
* These are excluded because their lifecycle is handled by Netty itself
* and Spark has no explicit effect on them.
*/
"netty.*",

/**
* Netty related internal threads.
* A single-threaded singleton EventExecutor inside Netty creates such threads.
* These are excluded because their lifecycle is handled by Netty itself
* and Spark has no explicit effect on them.
*/
"globalEventExecutor.*",

/**
* Netty related internal threads.
* Periodically checks whether a thread is alive and runs a task when it dies.
* These are excluded because their lifecycle is handled by Netty itself
* and Spark has no explicit effect on them.
*/
"threadDeathWatcher.*",

/**
* During [[SparkContext]] creation [[org.apache.spark.rpc.netty.NettyRpcEnv]]
* creates event loops. One is wrapped inside
* [[org.apache.spark.network.server.TransportServer]]
* and the other one is inside [[org.apache.spark.network.client.TransportClient]].
* The thread pools behind them shut down asynchronously, triggered by [[SparkContext#stop]].
* Manually checked that all of them stopped properly.
*/
"rpc-client.*",
"rpc-server.*",

/**
* During [[SparkContext]] creation BlockManager creates event loops. One is wrapped inside
* [[org.apache.spark.network.server.TransportServer]]
* and the other one is inside [[org.apache.spark.network.client.TransportClient]].
* The thread pools behind them shut down asynchronously, triggered by [[SparkContext#stop]].
* Manually checked that all of them stopped properly.
*/
"shuffle-client.*",
"shuffle-server.*"
)
private var threadNamesSnapshot: Set[String] = Set.empty

protected def doThreadPreAudit(): Unit = {
threadNamesSnapshot = runningThreadNames()
}

protected def doThreadPostAudit(): Unit = {
val shortSuiteName = this.getClass.getName.replaceAll("org.apache.spark", "o.a.s")

if (threadNamesSnapshot.nonEmpty) {
val remainingThreadNames = runningThreadNames().diff(threadNamesSnapshot)
.filterNot { s => threadWhiteList.exists(s.matches(_)) }
if (remainingThreadNames.nonEmpty) {
logWarning(s"\n\n===== POSSIBLE THREAD LEAK IN SUITE $shortSuiteName, " +
s"thread names: ${remainingThreadNames.mkString(", ")} =====\n")
}
} else {
logWarning("\n\n===== THREAD AUDIT POST ACTION CALLED " +
s"WITHOUT PRE ACTION IN SUITE $shortSuiteName =====\n")
}
}

private def runningThreadNames(): Set[String] = {
Thread.getAllStackTraces.keySet().asScala.map(_.getName).toSet
}
}
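
To make the whitelist filtering concrete, here is a small standalone sketch of what `doThreadPostAudit()` does with the two thread-name sets. The thread names are invented for illustration; only the whitelist patterns come from the trait above.

```scala
// Standalone illustration of the whitelist filtering used in doThreadPostAudit().
// The thread names below are made up for the example.
object ThreadAuditFilterExample {
  def main(args: Array[String]): Unit = {
    val threadWhiteList = Set("netty.*", "globalEventExecutor.*", "rpc-client.*", "rpc-server.*")

    // Snapshot taken by doThreadPreAudit() before the suite ran.
    val snapshot = Set("main", "ScalaTest-run")

    // Threads found alive by doThreadPostAudit() after the suite finished.
    val current = Set("main", "ScalaTest-run", "netty-rpc-connection-0", "my-leaked-pool-1")

    // New threads that do not match any whitelisted pattern are reported as possible leaks.
    val remaining = current.diff(snapshot)
      .filterNot { name => threadWhiteList.exists(name.matches(_)) }

    println(remaining) // Set(my-leaked-pool-1): the Netty thread matched "netty.*"
  }
}
```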
@@ -683,7 +683,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val conf = new SparkConf().set("spark.speculation", "true")
sc = new SparkContext("local", "test", conf)

val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"))
sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"))
Member: What was this change about? not shadowing?

Contributor Author: Originally the newly created instance was stored only in a local variable, so it was never saved into the member field and never freed properly. With this change the afterEach method stops it and frees up the resources.
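
A simplified, self-contained sketch of the shadowing pattern described above; the `Resource` class and suite skeletons are hypothetical, not the actual FakeTaskScheduler code:

```scala
// Hypothetical sketch of the local-variable shadowing issue and its fix.
class Resource {
  def stop(): Unit = println("resource stopped")
}

class LeakySuite {
  private var resource: Resource = null

  def testBody(): Unit = {
    val resource = new Resource() // local val shadows the member, which stays null
    // ... the test uses the local instance ...
  }

  def afterEach(): Unit = {
    // The member was never assigned, so the instance created in the test is never stopped.
    if (resource != null) resource.stop()
  }
}

class FixedSuite {
  private var resource: Resource = null

  def testBody(): Unit = {
    resource = new Resource() // assign the member so cleanup can see it
  }

  def afterEach(): Unit = {
    // Now the instance is stopped and its threads are freed.
    if (resource != null) resource.stop()
  }
}
```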

sched.initialize(new FakeSchedulerBackend() {
override def killTask(
taskId: Long,
@@ -709,6 +709,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
}
}
}
sched.dagScheduler.stop()
sched.setDAGScheduler(dagScheduler)

val singleTask = new ShuffleMapTask(0, 0, null, new Partition {
@@ -754,7 +755,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
sc.conf.set("spark.speculation", "true")

var killTaskCalled = false
val sched = new FakeTaskScheduler(sc, ("exec1", "host1"),
sched = new FakeTaskScheduler(sc, ("exec1", "host1"),
("exec2", "host2"), ("exec3", "host3"))
sched.initialize(new FakeSchedulerBackend() {
override def killTask(
@@ -789,6 +790,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
}
}
}
sched.dagScheduler.stop()
sched.setDAGScheduler(dagScheduler)

val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0,
Expand Down Expand Up @@ -1183,6 +1185,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
sc = new SparkContext("local", "test")
sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val mockDAGScheduler = mock(classOf[DAGScheduler])
sched.dagScheduler.stop()
sched.dagScheduler = mockDAGScheduler
val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1))
@@ -39,6 +39,7 @@ class SessionStateSuite extends SparkFunSuite
protected var activeSession: SparkSession = _

override def beforeAll(): Unit = {
super.beforeAll()
activeSession = SparkSession.builder().master("local").getOrCreate()
}

@@ -33,6 +33,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
private var targetPartitionSchema: StructType = _

override def beforeAll(): Unit = {
super.beforeAll()
targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int)
targetPartitionSchema = new StructType()
.add("b", IntegerType)
@@ -17,4 +17,25 @@

package org.apache.spark.sql.test

trait SharedSQLContext extends SQLTestUtils with SharedSparkSession
trait SharedSQLContext extends SQLTestUtils with SharedSparkSession {

/**
* Suites extending [[SharedSQLContext]] share resources (e.g. the SparkSession) in their tests.
* That trait initializes the Spark session in its [[beforeAll()]] implementation before the
* automatic thread snapshot is performed, so the audit code could fail to report threads leaked
* by that shared session.
*
* The behavior is overridden here to take the snapshot before the Spark session is initialized.
*/
override protected val enableAutoThreadAudit = false

protected override def beforeAll(): Unit = {
doThreadPreAudit()
super.beforeAll()
}

protected override def afterAll(): Unit = {
super.afterAll()
doThreadPostAudit()
}
}
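
As a usage sketch, a suite built on [[SharedSQLContext]] needs no audit-related code of its own, since the trait takes the thread snapshot before the shared session is created and checks it after the session is torn down; the suite and test below are hypothetical:

```scala
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite: thread audit is handled entirely by SharedSQLContext.
class ExampleQuerySuite extends QueryTest with SharedSQLContext {
  test("runs a query against the shared session") {
    checkAnswer(spark.range(3).toDF("id"), Seq(Row(0L), Row(1L), Row(2L)))
  }
}
```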
@@ -24,6 +24,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}

class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEach {

override protected val enableAutoThreadAudit = false
private var sc: SparkContext = null
private var hc: HiveContext = null

@@ -30,6 +30,7 @@ class HiveSessionStateSuite extends SessionStateSuite

override def beforeAll(): Unit = {
// Reuse the singleton session
super.beforeAll()
activeSession = spark
}

@@ -44,6 +44,8 @@ class HiveSparkSubmitSuite
with BeforeAndAfterEach
with ResetSystemProperties {

override protected val enableAutoThreadAudit = false

// TODO: rewrite these or mark them as slow tests to be run sparingly

override def beforeEach() {
@@ -67,6 +67,7 @@ class HiveClientSuite(version: String)
}

override def beforeAll() {
super.beforeAll()
client = init(true)
}

@@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.hive.HiveUtils

private[client] abstract class HiveVersionSuite(version: String) extends SparkFunSuite {
override protected val enableAutoThreadAudit = false
protected var client: HiveClient = null

protected def buildClient(hadoopConf: Configuration): HiveClient = {
@@ -50,6 +50,8 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils}
@ExtendedHiveTest
class VersionsSuite extends SparkFunSuite with Logging {

override protected val enableAutoThreadAudit = false

import HiveClientBuilder.buildClient

/**
@@ -48,6 +48,8 @@ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution}
abstract class HiveComparisonTest
extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen {

override protected val enableAutoThreadAudit = false

/**
* Path to the test datasets. We find this by looking up "hive-test-path-helper.txt" file.
*
@@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
import testImplicits._

override protected val enableAutoThreadAudit = false
override val dataSourceName: String =
classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName

@@ -26,6 +26,7 @@ import org.apache.spark.sql.hive.client.HiveClient


trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
override protected val enableAutoThreadAudit = false
protected val spark: SparkSession = TestHive.sparkSession
protected val hiveContext: TestHiveContext = TestHive
protected val hiveClient: HiveClient =