
Commit 28fde53

Davies Liu committed
Merge branch 'master' of github.com:apache/spark into python_tests
2 parents 945a2b5 + 3ae37b9 commit 28fde53

File tree

28 files changed: +733 -401 lines changed


core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 296 additions & 209 deletions
Large diffs are not rendered by default.

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 15 additions & 18 deletions
@@ -21,7 +21,7 @@ import java.io.File
 import java.lang.management.ManagementFactory
 import java.net.URL
 import java.nio.ByteBuffer
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

 import scala.collection.JavaConversions._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -60,8 +60,6 @@ private[spark] class Executor(

   private val conf = env.conf

-  @volatile private var isStopped = false
-
   // No ip or host:port - just hostname
   Utils.checkHost(executorHostname, "Expected executed slave to be a hostname")
   // must not have port specified.
@@ -114,6 +112,10 @@ private[spark] class Executor(
   // Maintains the list of running tasks.
   private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]

+  // Executor for the heartbeat task.
+  private val heartbeater = Executors.newSingleThreadScheduledExecutor(
+    Utils.namedThreadFactory("driver-heartbeater"))
+
   startDriverHeartbeater()

   def launchTask(
@@ -138,7 +140,8 @@ private[spark] class Executor(
   def stop(): Unit = {
     env.metricsSystem.report()
     env.rpcEnv.stop(executorEndpoint)
-    isStopped = true
+    heartbeater.shutdown()
+    heartbeater.awaitTermination(10, TimeUnit.SECONDS)
     threadPool.shutdown()
     if (!isLocal) {
       env.stop()
@@ -432,23 +435,17 @@ private[spark] class Executor(
   }

   /**
-   * Starts a thread to report heartbeat and partial metrics for active tasks to driver.
-   * This thread stops running when the executor is stopped.
+   * Schedules a task to report heartbeat and partial metrics for active tasks to driver.
    */
   private def startDriverHeartbeater(): Unit = {
     val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")
-    val thread = new Thread() {
-      override def run() {
-        // Sleep a random interval so the heartbeats don't end up in sync
-        Thread.sleep(intervalMs + (math.random * intervalMs).asInstanceOf[Int])
-        while (!isStopped) {
-          reportHeartBeat()
-          Thread.sleep(intervalMs)
-        }
-      }
+
+    // Wait a random interval so the heartbeats don't end up in sync
+    val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int]
+
+    val heartbeatTask = new Runnable() {
+      override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat())
     }
-    thread.setDaemon(true)
-    thread.setName("driver-heartbeater")
-    thread.start()
+    heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
   }
 }
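
The Executor hunks above replace a hand-rolled daemon thread (an isStopped flag plus a Thread.sleep loop) with a single-threaded ScheduledExecutorService, which lets stop() shut the heartbeater down deterministically via shutdown()/awaitTermination(). Below is a minimal standalone sketch of the same java.util.concurrent pattern; HeartbeatSketch, the shortened 1s interval, and the println stand-in for reportHeartBeat() are illustrative assumptions, not Spark code.

import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}

object HeartbeatSketch {
  def main(args: Array[String]): Unit = {
    // Interval shortened from Spark's 10s default so the demo finishes quickly
    val intervalMs = 1000L

    // Named daemon thread, playing the role of Utils.namedThreadFactory("driver-heartbeater")
    val factory = new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = new Thread(r, "driver-heartbeater")
        t.setDaemon(true)
        t
      }
    }
    val heartbeater = Executors.newSingleThreadScheduledExecutor(factory)

    // Randomized initial delay so heartbeats from many executors don't end up in sync
    val initialDelay = intervalMs + (math.random * intervalMs).toLong

    val heartbeatTask = new Runnable {
      // Catch exceptions: a task that throws would silently cancel all future scheduled runs,
      // which is why the real code wraps reportHeartBeat() in Utils.logUncaughtExceptions
      override def run(): Unit =
        try println("heartbeat") catch { case e: Exception => e.printStackTrace() }
    }
    heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)

    Thread.sleep(5000)

    // Mirrors the new Executor.stop(): stop scheduling, then wait briefly for in-flight work
    heartbeater.shutdown()
    heartbeater.awaitTermination(10, TimeUnit.SECONDS)
  }
}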

core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala

Lines changed: 4 additions & 3 deletions
@@ -86,11 +86,11 @@ private[nio] class ConnectionManager(
     conf.get("spark.network.timeout", "120s"))

   // Get the thread counts from the Spark Configuration.
-  // 
+  //
   // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value,
   // we only query for the minimum value because we are using LinkedBlockingDeque.
-  // 
-  // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is 
+  //
+  // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is
   // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min"
   // parameter is necessary.
   private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20)
@@ -989,6 +989,7 @@ private[nio] class ConnectionManager(

   def stop() {
     ackTimeoutMonitor.stop()
+    selector.wakeup()
     selectorThread.interrupt()
     selectorThread.join()
     selector.close()
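
The comment kept in this hunk explains why ConnectionManager only reads "min" thread counts: when a ThreadPoolExecutor is backed by an unbounded queue, it never grows past corePoolSize, so a separate maximum is meaningless. A small standalone sketch of that JDK behaviour, with made-up pool sizes (CorePoolSizeSketch is illustrative, not Spark code):

import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

object CorePoolSizeSketch {
  def main(args: Array[String]): Unit = {
    // Unbounded work queue: the pool never creates more than corePoolSize (2) threads,
    // so the maximumPoolSize argument (8) is effectively ignored.
    val pool = new ThreadPoolExecutor(
      2, 8, 60L, TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]())

    (1 to 20).foreach { _ =>
      pool.execute(new Runnable {
        override def run(): Unit = Thread.sleep(100)
      })
    }

    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
    println(s"largest pool size: ${pool.getLargestPoolSize}")  // prints 2, not 8
  }
}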

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 3 additions & 5 deletions
@@ -142,11 +142,10 @@ private[spark] class TaskSchedulerImpl(

     if (!isLocal && conf.getBoolean("spark.speculation", false)) {
       logInfo("Starting speculative execution thread")
-      import sc.env.actorSystem.dispatcher
       sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds,
             SPECULATION_INTERVAL_MS milliseconds) {
         Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() }
-      }
+      }(sc.env.actorSystem.dispatcher)
     }
   }

@@ -394,7 +393,7 @@ private[spark] class TaskSchedulerImpl(

   def error(message: String) {
     synchronized {
-      if (activeTaskSets.size > 0) {
+      if (activeTaskSets.nonEmpty) {
         // Have each task set throw a SparkException with the error
         for ((taskSetId, manager) <- activeTaskSets) {
           try {
@@ -407,8 +406,7 @@ private[spark] class TaskSchedulerImpl(
         // No task sets are active but we still got an error. Just exit since this
         // must mean the error is during registration.
         // It might be good to do something smarter here in the future.
-        logError("Exiting due to error from cluster scheduler: " + message)
-        System.exit(1)
+        throw new SparkException(s"Exiting due to error from cluster scheduler: $message")
       }
     }
   }

core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala

Lines changed: 6 additions & 3 deletions
@@ -118,9 +118,12 @@ private[spark] class SparkDeploySchedulerBackend(
     notifyContext()
     if (!stopping) {
       logError("Application has been killed. Reason: " + reason)
-      scheduler.error(reason)
-      // Ensure the application terminates, as we can no longer run jobs.
-      sc.stop()
+      try {
+        scheduler.error(reason)
+      } finally {
+        // Ensure the application terminates, as we can no longer run jobs.
+        sc.stop()
+      }
     }
   }

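Read together with the TaskSchedulerImpl hunk above (error() now throws a SparkException instead of calling System.exit(1)), this try/finally guarantees that sc.stop() still runs when scheduler.error(reason) throws. A minimal sketch of that control flow, using hypothetical stand-ins for the scheduler and the context:

object StopOnErrorSketch {
  // Stand-in for TaskSchedulerImpl.error after this commit: it throws instead of exiting
  def schedulerError(reason: String): Unit =
    throw new RuntimeException(s"Exiting due to error from cluster scheduler: $reason")

  // Stand-in for SparkContext.stop()
  def stopContext(): Unit = println("SparkContext stopped")

  def main(args: Array[String]): Unit = {
    try {
      try {
        schedulerError("Application has been killed")
      } finally {
        stopContext()  // runs even though schedulerError threw, mirroring the diff above
      }
    } catch {
      case e: RuntimeException => println(s"propagated: ${e.getMessage}")
    }
  }
}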

core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala

Lines changed: 13 additions & 6 deletions
@@ -20,12 +20,12 @@ package org.apache.spark.scheduler.local
 import java.nio.ByteBuffer
 import java.util.concurrent.{Executors, TimeUnit}

-import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv}
-import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState}
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.{Executor, ExecutorBackend}
+import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.util.Utils

 private case class ReviveOffers()

@@ -71,11 +71,15 @@ private[spark] class LocalEndpoint(

     case KillTask(taskId, interruptThread) =>
       executor.killTask(taskId, interruptThread)
+  }

+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case StopExecutor =>
       executor.stop()
+      context.reply(true)
   }

+
   def reviveOffers() {
     val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
     val tasks = scheduler.resourceOffers(offers).flatten
@@ -104,8 +108,11 @@ private[spark] class LocalEndpoint(
  * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks
  * on a single Executor (created by the LocalBackend) running locally.
  */
-private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
-  extends SchedulerBackend with ExecutorBackend {
+private[spark] class LocalBackend(
+    conf: SparkConf,
+    scheduler: TaskSchedulerImpl,
+    val totalCores: Int)
+  extends SchedulerBackend with ExecutorBackend with Logging {

   private val appId = "local-" + System.currentTimeMillis
   var localEndpoint: RpcEndpointRef = null
@@ -116,7 +123,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:
   }

   override def stop() {
-    localEndpoint.send(StopExecutor)
+    localEndpoint.sendWithReply(StopExecutor)
   }

   override def reviveOffers() {
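
The LocalBackend hunks move StopExecutor handling into receiveAndReply and switch stop() from send to sendWithReply, so shutting down blocks until the executor has actually stopped instead of racing ahead. The sketch below illustrates the fire-and-forget versus request-reply distinction with plain Scala futures; StopAckSketch is a hypothetical illustration of the idea, not the Spark RPC API.

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object StopAckSketch {
  def main(args: Array[String]): Unit = {
    val stopped = Promise[Boolean]()

    // Stand-in for the endpoint's receiveAndReply branch: do the work, then acknowledge
    Future {
      println("executor.stop()")  // the side effect the caller must not race past
      stopped.success(true)       // analogous to context.reply(true)
    }

    // A fire-and-forget send would return immediately; waiting on the acknowledgement
    // (what sendWithReply provides) guarantees the stop has completed before we proceed.
    val ack = Await.result(stopped.future, 10.seconds)
    println(s"acknowledged: $ack")
  }
}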

core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala

Lines changed: 7 additions & 1 deletion
@@ -43,7 +43,13 @@ private[spark] trait ActorLogReceive {

     private val _receiveWithLogging = receiveWithLogging

-    override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o)
+    override def isDefinedAt(o: Any): Boolean = {
+      val handled = _receiveWithLogging.isDefinedAt(o)
+      if (!handled) {
+        log.debug(s"Received unexpected actor system event: $o")
+      }
+      handled
+    }

     override def apply(o: Any): Unit = {
       if (log.isDebugEnabled) {
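
This hunk makes isDefinedAt log any message that the actor's receiveWithLogging does not cover, so unexpected actor-system events surface at debug level instead of disappearing silently. A self-contained sketch of the same PartialFunction technique, with println standing in for log.debug (names are illustrative, not Spark code):

object LoggingReceiveSketch {
  // Wraps a handler so that messages it does not cover are reported, mirroring the
  // isDefinedAt override added above.
  def withUnhandledLogging[A](handler: PartialFunction[A, Unit]): PartialFunction[A, Unit] =
    new PartialFunction[A, Unit] {
      override def isDefinedAt(o: A): Boolean = {
        val handled = handler.isDefinedAt(o)
        if (!handled) println(s"Received unexpected actor system event: $o")
        handled
      }
      override def apply(o: A): Unit = handler(o)
    }

  def main(args: Array[String]): Unit = {
    val receive = withUnhandledLogging[Any] { case "ping" => println("pong") }
    if (receive.isDefinedAt("ping")) receive("ping")  // handled: prints "pong"
    receive.isDefinedAt(42)                           // unhandled: logs the event
  }
}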

core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala

Lines changed: 0 additions & 6 deletions
@@ -56,19 +56,13 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit
     // Min < 0
     val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1")
     intercept[SparkException] { contexts += new SparkContext(conf1) }
-    SparkEnv.get.stop()
-    SparkContext.clearActiveContext()

     // Max < 0
     val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1")
     intercept[SparkException] { contexts += new SparkContext(conf2) }
-    SparkEnv.get.stop()
-    SparkContext.clearActiveContext()

     // Both min and max, but min > max
     intercept[SparkException] { createSparkContext(2, 1) }
-    SparkEnv.get.stop()
-    SparkContext.clearActiveContext()

     // Both min and max, and min == max
     val sc1 = createSparkContext(1, 1)

examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ object DirectKafkaWordCount {
        | <brokers> is a list of one or more Kafka brokers
        | <topics> is a list of one or more kafka topics to consume from
        |
-        """".stripMargin)
+        """.stripMargin)
       System.exit(1)
     }


external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala

Lines changed: 1 addition & 2 deletions
@@ -23,10 +23,9 @@ import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskC
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.NextIterator

-import java.util.Properties
 import kafka.api.{FetchRequestBuilder, FetchResponse}
 import kafka.common.{ErrorMapping, TopicAndPartition}
-import kafka.consumer.{ConsumerConfig, SimpleConsumer}
+import kafka.consumer.SimpleConsumer
 import kafka.message.{MessageAndMetadata, MessageAndOffset}
 import kafka.serializer.Decoder
 import kafka.utils.VerifiableProperties

mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ import java.util.UUID
 private[ml] trait Identifiable extends Serializable {

   /**
-   * A unique id for the object. The default implementation concatenates the class name, "-", and 8
+   * A unique id for the object. The default implementation concatenates the class name, "_", and 8
    * random hex chars.
    */
   private[ml] val uid: String =
mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 10 additions & 5 deletions
@@ -227,15 +227,16 @@ object Vectors {
    * @param elements vector elements in (index, value) pairs.
    */
   def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {
-    require(size > 0)
+    require(size > 0, "The size of the requested sparse vector must be greater than 0.")

     val (indices, values) = elements.sortBy(_._1).unzip
     var prev = -1
     indices.foreach { i =>
       require(prev < i, s"Found duplicate indices: $i.")
       prev = i
     }
-    require(prev < size)
+    require(prev < size, s"You may not write an element to index $prev because the declared " +
+      s"size of your vector is $size")

     new SparseVector(size, indices.toArray, values.toArray)
   }
@@ -309,7 +310,8 @@ object Vectors {
    * @return norm in L^p^ space.
    */
   def norm(vector: Vector, p: Double): Double = {
-    require(p >= 1.0)
+    require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " +
+      s"You specified p=$p.")
     val values = vector match {
       case DenseVector(vs) => vs
       case SparseVector(n, ids, vs) => vs
@@ -360,7 +362,8 @@ object Vectors {
    * @return squared distance between two Vectors.
    */
   def sqdist(v1: Vector, v2: Vector): Double = {
-    require(v1.size == v2.size, "vector dimension mismatch")
+    require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" +
+      s"=${v2.size}.")
     var squaredDistance = 0.0
     (v1, v2) match {
       case (v1: SparseVector, v2: SparseVector) =>
@@ -518,7 +521,9 @@ class SparseVector(
     val indices: Array[Int],
     val values: Array[Double]) extends Vector {

-  require(indices.length == values.length)
+  require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
+    s" indices match the dimension of the values. You provided ${indices.size} indices and " +
+    s" ${values.size} values.")

   override def toString: String =
     "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]"))

mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala

Lines changed: 3 additions & 6 deletions
@@ -17,16 +17,13 @@

 package org.apache.spark.ml.param

+import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
+
 /** A subclass of Params for testing. */
-class TestParams extends Params {
+class TestParams extends Params with HasMaxIter with HasInputCol {

-  val maxIter = new IntParam(this, "maxIter", "max number of iterations")
   def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
-  def getMaxIter: Int = getOrDefault(maxIter)
-
-  val inputCol = new Param[String](this, "inputCol", "input column name")
   def setInputCol(value: String): this.type = { set(inputCol, value); this }
-  def getInputCol: String = getOrDefault(inputCol)

   setDefault(maxIter -> 10)


python/pyspark/ml/classification.py

Lines changed: 2 additions & 1 deletion
@@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
                  maxIter=100, regParam=0.1)
         """
         super(LogisticRegression, self).__init__()
+        self._setDefault(maxIter=100, regParam=0.1)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

@@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
         Sets params for logistic regression.
         """
         kwargs = self.setParams._input_kwargs
-        return self._set_params(**kwargs)
+        return self._set(**kwargs)

     def _create_model(self, java_model):
         return LogisticRegressionModel(java_model)
