SPARK-1039: Set the upper bound for retry times of in-cluster drivers #8

Closed · wants to merge 1 commit
core/src/main/scala/org/apache/spark/deploy/DeployMessages.scala
@@ -55,7 +55,7 @@ private[deploy] object DeployMessages {
extends DeployMessage

case class DriverStateChanged(
driverId: String,
driverID: String,
state: DriverState,
exception: Option[Exception])
extends DeployMessage
core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -33,4 +33,9 @@ private[spark] class DriverInfo(
@transient var exception: Option[Exception] = None
/* Most recent worker assigned to this driver */
@transient var worker: Option[WorkerInfo] = None

/**
* Number of times the master has retried launching this in-cluster driver
*/
@transient var retriedcountOnMaster = 0
}
70 changes: 43 additions & 27 deletions core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -71,8 +71,8 @@ private[spark] class Master(
var nextAppNumber = 0

val appIdToUI = new HashMap[String, SparkUI]

val drivers = new HashSet[DriverInfo]
val drivers = new HashMap[String, DriverInfo] // driverId -> DriverInfo
val driverAssignments = new HashMap[String, HashSet[String]] // driverId -> HashSet[workerId]
val completedDrivers = new ArrayBuffer[DriverInfo]
val waitingDrivers = new ArrayBuffer[DriverInfo] // Drivers currently spooled for scheduling
var nextDriverNumber = 0
@@ -209,12 +209,9 @@ private[spark] class Master(
val driver = createDriver(description)
persistenceEngine.addDriver(driver)
waitingDrivers += driver
drivers.add(driver)
drivers += driver.id -> driver
driverAssignments(driver.id) = new HashSet[String]
schedule()

// TODO: It might be good to instead have the submission client poll the master to determine
// the current status of the driver. For now it's simply "fire and forget".

sender ! SubmitDriverResponse(true, Some(driver.id),
s"Driver successfully submitted as ${driver.id}")
}
@@ -226,12 +223,12 @@
sender ! KillDriverResponse(driverId, success = false, msg)
} else {
logInfo("Asked to kill driver " + driverId)
val driver = drivers.find(_.id == driverId)
val driver = drivers.get(driverId)
driver match {
case Some(d) =>
if (waitingDrivers.contains(d)) {
waitingDrivers -= d
self ! DriverStateChanged(driverId, DriverState.KILLED, None)
self ! DriverStateChanged(d.id, DriverState.KILLED, None)
}
else {
// We just notify the worker to kill the driver here. The final bookkeeping occurs
@@ -254,7 +251,7 @@
}

case RequestDriverStatus(driverId) => {
(drivers ++ completedDrivers).find(_.id == driverId) match {
(drivers.values ++ completedDrivers).find(_.id == driverId) match {
case Some(driver) =>
sender ! DriverStatusResponse(found = true, Some(driver.state),
driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)
@@ -305,12 +302,25 @@
}
}

case DriverStateChanged(driverId, state, exception) => {
case DriverStateChanged(driverID, state, exception) => {
state match {
case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED =>
removeDriver(driverId, state, exception)
case DriverState.FINISHED | DriverState.KILLED =>
removeDriver(driverID, state, exception)
case DriverState.ERROR | DriverState.FAILED =>
drivers.get(driverID) match {
case Some(driver) =>
val maxRetry = conf.getInt("spark.driver.maxRetry", 0)
if ((maxRetry != 0 && driver.retriedcountOnMaster < Math.min(maxRetry, workers.size)) ||
(maxRetry == 0 && driver.retriedcountOnMaster < workers.size)) {
// Recover the driver by relaunching it
relaunchDriver(driver)
} else {
removeDriver(driver.id, state, exception)
}
case None =>
}
case _ =>
throw new Exception(s"Received unexpected state update for driver $driverId: $state")
throw new Exception(s"Received unexpected state update for driver $driverID: $state")
}
}
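
Note: the case above caps master-side relaunches at min(spark.driver.maxRetry, number of registered workers) when the setting is non-zero, and at the worker count otherwise (the default, since "spark.driver.maxRetry" defaults to 0). A minimal sketch of that rule, not part of the patch; shouldRelaunch, retriedCount and workerCount are illustrative names:

    // Sketch only -- restates the condition used in the DriverStateChanged case above.
    def shouldRelaunch(retriedCount: Int, maxRetry: Int, workerCount: Int): Boolean = {
      // maxRetry == 0 (the default for "spark.driver.maxRetry") means "bounded only
      // by the number of workers"; otherwise the tighter of the two bounds applies.
      val bound = if (maxRetry != 0) math.min(maxRetry, workerCount) else workerCount
      retriedCount < bound
    }

    // With 3 workers: the default config allows up to 3 relaunches,
    // while maxRetry = 2 gives up after 2.
    assert(shouldRelaunch(retriedCount = 2, maxRetry = 0, workerCount = 3))
    assert(!shouldRelaunch(retriedCount = 2, maxRetry = 2, workerCount = 3))

Using the worker count as the ceiling means a driver is never relaunched more times than there are distinct workers to try.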

@@ -350,7 +360,7 @@ private[spark] class Master(
}

for (driverId <- driverIds) {
drivers.find(_.id == driverId).foreach { driver =>
drivers.get(driverId).foreach { driver =>
driver.worker = Some(worker)
driver.state = DriverState.RUNNING
worker.drivers(driverId) = driver
@@ -373,7 +383,7 @@

case RequestMasterState => {
sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
drivers.toArray, completedDrivers.toArray, state)
drivers.values.toArray, completedDrivers.toArray, state)
}

case CheckForWorkerTimeOut => {
@@ -405,7 +415,7 @@ private[spark] class Master(
for (driver <- storedDrivers) {
// Here we just read in the list of drivers. Any drivers associated with now-lost workers
// will be re-launched when we detect that the worker is missing.
drivers += driver
drivers += (driver.id -> driver)
}

for (worker <- storedWorkers) {
@@ -432,14 +442,14 @@
apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)

// Reschedule drivers which were not claimed by any workers
drivers.filter(_.worker.isEmpty).foreach { d =>
logWarning(s"Driver ${d.id} was not found after master recovery")
if (d.desc.supervise) {
logWarning(s"Re-launching ${d.id}")
relaunchDriver(d)
drivers.filter(_._2.worker.isEmpty).foreach { d =>
logWarning(s"Driver ${d._1} was not found after master recovery")
if (d._2.desc.supervise) {
logWarning(s"Re-launching ${d._1}")
relaunchDriver(d._2)
} else {
removeDriver(d.id, DriverState.ERROR, None)
logWarning(s"Did not re-launch ${d.id} because it was not supervised")
removeDriver(d._1, DriverState.ERROR, None)
logWarning(s"Did not re-launch ${d._1} because it was not supervised")
}
}
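
Since drivers is now a HashMap, the loop above works on (id, info) pairs via _._1 and _._2. Purely as an illustration (not part of the patch), the same loop with the entries destructured reads a little more directly; it assumes the enclosing Master's members (drivers, relaunchDriver, removeDriver, DriverState):

    // Illustrative rewrite only; filter copies the map first, so removing entries
    // inside the foreach is safe, exactly as in the patch.
    drivers.filter { case (_, info) => info.worker.isEmpty }.foreach { case (driverId, info) =>
      logWarning(s"Driver $driverId was not found after master recovery")
      if (info.desc.supervise) {
        logWarning(s"Re-launching $driverId")
        relaunchDriver(info)
      } else {
        removeDriver(driverId, DriverState.ERROR, None)
        logWarning(s"Did not re-launch $driverId because it was not supervised")
      }
    }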

@@ -468,7 +478,9 @@ private[spark] class Master(
val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers
for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) {
for (driver <- waitingDrivers) {
if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
if (worker.memoryFree >= driver.desc.mem &&
worker.coresFree >= driver.desc.cores &&
!driverAssignments(driver.id).contains(worker.id)) {
launchDriver(worker, driver)
waitingDrivers -= driver
}
@@ -582,6 +594,8 @@ private[spark] class Master(
def relaunchDriver(driver: DriverInfo) {
driver.worker = None
driver.state = DriverState.RELAUNCHING
// Incremented for both worker failures and program failures
driver.retriedcountOnMaster += 1
waitingDrivers += driver
schedule()
}
@@ -720,16 +734,18 @@ private[spark] class Master(
def launchDriver(worker: WorkerInfo, driver: DriverInfo) {
logInfo("Launching driver " + driver.id + " on worker " + worker.id)
worker.addDriver(driver)
driverAssignments(driver.id) += worker.id
driver.worker = Some(worker)
worker.actor ! LaunchDriver(driver.id, driver.desc)
driver.state = DriverState.RUNNING
}

def removeDriver(driverId: String, finalState: DriverState, exception: Option[Exception]) {
drivers.find(d => d.id == driverId) match {
drivers.get(driverId) match {
case Some(driver) =>
logInfo(s"Removing driver: $driverId")
drivers -= driver
drivers -= driverId
driverAssignments -= driver.id
completedDrivers += driver
persistenceEngine.removeDriver(driver)
driver.state = finalState
core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -28,24 +28,29 @@ import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileUtil, Path}

import org.apache.spark.Logging
import org.apache.spark.deploy.{Command, DriverDescription}
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.deploy.{DriverDescription, Command}
import org.apache.spark.deploy.DeployMessages.DriverStateChanged
import org.apache.spark.deploy.master.DriverState
import org.apache.spark.deploy.master.{DriverInfo, DriverState}
import org.apache.spark.deploy.master.DriverState.DriverState

/**
* Manages the execution of one driver, including automatically restarting the driver on failure.
*/
private[spark] class DriverRunner(
val driverId: String,
val driverDesc: DriverDescription,
val workDir: File,
val sparkHome: File,
val driverDesc: DriverDescription,
val worker: ActorRef,
val workerUrl: String)
val workerUrl: String,
val conf: SparkConf)
extends Logging {

class FailedTooManyTimesException(retryN: Int) extends Exception {
var retryCount = retryN
}

@volatile var process: Option[Process] = None
@volatile var killed = false

@@ -54,6 +59,9 @@ private[spark] class DriverRunner(
var finalException: Option[Exception] = None
var finalExitCode: Option[Int] = None

// Retry counters
var retryNum = 0

// Decoupled for testing
private[deploy] def setClock(_clock: Clock) = clock = _clock
private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper
@@ -97,7 +105,6 @@ private[spark] class DriverRunner(
}

finalState = Some(state)

worker ! DriverStateChanged(driverId, state, finalException)
}
}.start()
@@ -184,7 +191,6 @@ private[spark] class DriverRunner(
val successfulRunDuration = 5

var keepTrying = !killed

while (keepTrying) {
logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\""))

@@ -199,15 +205,19 @@
if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) {
waitSeconds = 1
}

if (supervise && exitCode != 0 && !killed) {
val maxRetry = conf.getInt("spark.driver.maxRetry", 0)
keepTrying = supervise && exitCode != 0 && !killed
if (keepTrying) {
retryNum += 1
if (retryNum > maxRetry)
throw new FailedTooManyTimesException(retryNum)
// Sleep only when we want to retry
logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
sleeper.sleep(waitSeconds)
waitSeconds = waitSeconds * 2 // exponential back-off
} else {
finalExitCode = Some(exitCode)
}

keepTrying = supervise && exitCode != 0 && !killed
finalExitCode = Some(exitCode)
}
}
}
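
A condensed sketch of the worker-side policy above, not the DriverRunner itself; superviseLoop and runOnce are hypothetical stand-ins for the surrounding runCommand loop and the process launch:

    // Retry a supervised driver that exits non-zero, doubling the wait between
    // attempts, and give up once "spark.driver.maxRetry" attempts are exceeded
    // (the patch signals this with FailedTooManyTimesException; the sketch
    // returns None instead).
    def superviseLoop(runOnce: () => Int, maxRetry: Int): Option[Int] = {
      var waitSeconds = 1
      var retries = 0
      var exitCode = runOnce()
      while (exitCode != 0) {
        retries += 1
        if (retries > maxRetry) return None
        Thread.sleep(waitSeconds * 1000L)   // stands in for sleeper.sleep(waitSeconds)
        waitSeconds *= 2                    // exponential back-off, as in the patch
        exitCode = runOnce()
      }
      Some(exitCode)                        // successful exit
    }

The real loop additionally resets waitSeconds to 1 after a run that lasted longer than successfulRunDuration seconds, and only keeps retrying while supervise is set and the runner has not been killed; with the default maxRetry of 0 the worker gives up immediately and leaves recovery to the master.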
core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -265,7 +265,7 @@ private[spark] class Worker(

case LaunchDriver(driverId, driverDesc) => {
logInfo(s"Asked to launch driver $driverId")
val driver = new DriverRunner(driverId, workDir, sparkHome, driverDesc, self, akkaUrl)
val driver = new DriverRunner(driverId, driverDesc, workDir, sparkHome, self, akkaUrl, conf)
drivers(driverId) = driver
driver.start()

@@ -286,11 +286,11 @@
case DriverStateChanged(driverId, state, exception) => {
state match {
case DriverState.ERROR =>
logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}")
logWarning(s"Driver ${driverId} failed with unrecoverable exception: ${exception.get}")
case DriverState.FINISHED =>
logInfo(s"Driver $driverId exited successfully")
logInfo(s"Driver ${driverId} exited successfully")
case DriverState.KILLED =>
logInfo(s"Driver $driverId was killed by user")
logInfo(s"Driver ${driverId} was killed by user")
}
masterLock.synchronized {
master ! DriverStateChanged(driverId, state, exception)
core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
@@ -22,7 +22,7 @@ import scala.xml.Node

import akka.pattern.ask
import javax.servlet.http.HttpServletRequest
import org.json4s.JValue
import net.liftweb.json.JsonAST.JValue

import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse}
@@ -144,7 +144,7 @@ private[spark] class IndexPage(parent: WorkerWebUI) {

def driverRow(driver: DriverRunner): Seq[Node] = {
<tr>
<td>{driver.driverId}</td>
<td>{driver.driverDesc}</td>
<td>{driver.driverDesc.command.arguments(1)}</td>
<td>{driver.finalState.getOrElse(DriverState.RUNNING)}</td>
<td sorttable_customkey={driver.driverDesc.cores.toString}>
core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -119,8 +119,8 @@ class JsonProtocolSuite extends FunSuite {
new File("sparkHome"), new File("workDir"), "akka://worker", ExecutorState.RUNNING)
}
def createDriverRunner(): DriverRunner = {
new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(),
null, "akka://worker")
new DriverRunner("driverId", createDriverDesc(), new File("workDir"), new File("sparkHome"),
null, "akka://worker", null)
}

def assertValidJson(json: JValue) {