[SPARK-14475] Propagate user-defined context from driver to executors #12248

Closed · wants to merge 9 commits
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -602,8 +602,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}

/**
* Set a local property that affects jobs submitted from this thread, such as the
* Spark fair scheduler pool.
* Set a local property that affects jobs submitted from this thread, such as the Spark fair
* scheduler pool. User-defined properties may also be set here. These properties are propagated
* through to worker tasks and can be accessed there via
* [[org.apache.spark.TaskContext#getLocalProperty]].
*/
def setLocalProperty(key: String, value: String) {
if (value == null) {
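For context, a minimal usage sketch of the new round trip, assuming a spark-shell style session where `sc` is an existing SparkContext built with this patch; the key "myapp.user" and its value are placeholders, not anything defined by Spark:

```scala
import org.apache.spark.TaskContext

// Driver side: set a thread-local property before submitting the job.
sc.setLocalProperty("myapp.user", "alice")

// Executor side: the property is shipped with each task and can be read
// from the TaskContext inside the closure.
val users = sc.parallelize(1 to 4, 2)
  .map(_ => TaskContext.get().getLocalProperty("myapp.user"))
  .collect()

assert(users.forall(_ == "alice"))
```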
9 changes: 8 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark

import java.io.Serializable
import java.util.Properties

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
@@ -64,7 +65,7 @@ object TaskContext {
* An empty task context that does not represent an actual task.
*/
private[spark] def empty(): TaskContextImpl = {
new TaskContextImpl(0, 0, 0, 0, null, null)
new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
}

}
@@ -162,6 +163,12 @@ abstract class TaskContext extends Serializable {
*/
def taskAttemptId(): Long

/**
* Get a local property set upstream in the driver, or null if it is missing. See also
* [[org.apache.spark.SparkContext.setLocalProperty]].
*/
def getLocalProperty(key: String): String

@DeveloperApi
def taskMetrics(): TaskMetrics

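Because getLocalProperty returns null for keys the driver never set, callers that want a default can wrap the call in Option. A brief sketch of code running inside a task body, with the key and fallback value chosen purely for illustration:

```scala
import org.apache.spark.TaskContext

// Inside a task's closure: fall back to a default when the driver never set the key.
val timeoutMs = Option(TaskContext.get().getLocalProperty("myapp.timeout.ms"))
  .map(_.toLong)
  .getOrElse(30000L)
```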
5 changes: 5 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,6 +17,8 @@

package org.apache.spark

import java.util.Properties

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.executor.TaskMetrics
@@ -32,6 +34,7 @@ private[spark] class TaskContextImpl(
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll())
extends TaskContext
@@ -118,6 +121,8 @@ private[spark] class TaskContextImpl(

override def isInterrupted(): Boolean = interrupted

override def getLocalProperty(key: String): String = localProperties.getProperty(key)

override def getMetricsSources(sourceName: String): Seq[Source] =
metricsSystem.getSourcesByName(sourceName)

17 changes: 16 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException}
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}

import scala.collection.JavaConverters._
@@ -206,9 +207,16 @@ private[spark] class Executor(
startGCTime = computeTotalGcTime()

try {
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
val (taskFiles, taskJars, taskProps, taskBytes) =
Task.deserializeWithDependencies(serializedTask)

// Must be set before updateDependencies() is called, in case fetching dependencies
// requires access to properties contained within (e.g. for access control).
Executor.taskDeserializationProps.set(taskProps)

updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
task.localProperties = taskProps
task.setTaskMemoryManager(taskMemoryManager)

// If this task has been killed before we deserialized it, let's quit now. Otherwise,
@@ -506,3 +514,10 @@ private[spark] class Executor(
heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
}
}

private[spark] object Executor {
// This is reserved for internal use by components that need to read task properties before a
// task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
// used instead.
val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]
}
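As an illustration of the intended consumer of this thread-local, here is a hypothetical sketch of a component that checks an access-control property while dependencies are being fetched, before any TaskContext exists. FetchGuard, the subpackage, and the property key are all invented for the example and are not part of this patch:

```scala
package org.apache.spark.myapp // hypothetical subpackage so the private[spark] Executor object is visible

import java.util.Properties

import org.apache.spark.executor.Executor

// Hypothetical guard consulted while dependencies are being fetched, i.e. after
// Executor.taskDeserializationProps.set(...) but before the task is deserialized
// and before any TaskContext has been set for the thread.
private[spark] object FetchGuard {
  def checkAccess(path: String): Unit = {
    val props: Properties = Executor.taskDeserializationProps.get()
    val user = Option(props).map(_.getProperty("myapp.auth.user")).orNull
    require(user != null, s"No authenticated user in task properties; refusing to fetch $path")
  }
}
```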
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1036,7 +1036,7 @@ class DAGScheduler(
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
taskBinary, part, locs, stage.internalAccumulators)
taskBinary, part, locs, stage.internalAccumulators, properties)
}

case stage: ResultStage =>
@@ -1046,7 +1046,7 @@
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId,
taskBinary, part, locs, id, stage.internalAccumulators)
taskBinary, part, locs, id, properties, stage.internalAccumulators)
}
}
} catch {
core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler

import java.io._
import java.nio.ByteBuffer
import java.util.Properties

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
@@ -38,6 +39,7 @@ import org.apache.spark.rdd.RDD
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
* @param localProperties copy of thread-local properties set by the user on the driver side.
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
@@ -49,8 +51,9 @@ private[spark] class ResultTask[T, U](
partition: Partition,
locs: Seq[TaskLocation],
val outputId: Int,
localProperties: Properties,
_initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll())
extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Serializable {

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util.Properties

import scala.language.existentials

@@ -42,20 +43,22 @@ import org.apache.spark.shuffle.ShuffleWriter
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
* @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] class ShuffleMapTask(
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation],
_initialAccums: Seq[Accumulator[_]])
extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
_initialAccums: Seq[Accumulator[_]],
localProperties: Properties)
extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Logging {

/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties)
Member

I wonder if we can avoid making empty Properties all over... an Option[Properties]? A setter that is called only where needed?

Contributor

Properties objects are kind of analogous to Maps, and I think that Option[Map] would be a weird type in the same sense that Option[Set] (or any other collection type) is usually a code smell. So this is fine with me as is.

Contributor Author

It seemed safer to make it required. I can change this to an option if you think creating a Properties each time is too much overhead.

Member

Fair enough, I suppose allocating the empty map/properties object isn't that expensive.
}

@transient private val preferredLocs: Seq[TaskLocation] = {
20 changes: 17 additions & 3 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler

import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable.HashMap

@@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
* @param initialAccumulators initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
* @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
val initialAccumulators: Seq[Accumulator[_]],
@transient var localProperties: Properties) extends Serializable {

/**
* Called by [[org.apache.spark.executor.Executor]] to run this task.
@@ -71,6 +74,7 @@ private[spark] abstract class Task[T](
taskAttemptId,
attemptNumber,
taskMemoryManager,
localProperties,
metricsSystem,
initialAccumulators)
TaskContext.setTaskContext(context)
@@ -206,6 +210,11 @@ private[spark] object Task {
dataOut.writeLong(timestamp)
}

// Write the task properties separately so it is available before full task deserialization.
Contributor

Since the properties aren't transient in Task, I guess this means that we'll write them out twice. If we want to avoid this, we can make localProperties a @transient var that is private[spark], then re-set the field after deserializing the task. Tasks are sent to executors using broadcast variables, so the extra space only makes a difference for the first task from a stage that's run on an executor.

As a result, if we think that these serialized properties will typically be small then the extra space savings probably aren't a huge deal, but if we want to heavily optimize then we can do the var trick.

Contributor Author

Done

val propBytes = Utils.serialize(task.localProperties)
Contributor

Just curious, why not serializer.serialize(..)? This is fine, but just wondering.

Contributor Author

Wasn't sure how to deserialize on the Executor side. Perhaps env.serializer there?

Contributor

Hmm, good point. Utils.serialize is fine here, since it doesn't matter whether we use a custom serializer and because today it's always going to be JavaSerializer anyway.

dataOut.writeInt(propBytes.length)
dataOut.write(propBytes)

// Write the task itself and finish
dataOut.flush()
val taskBytes = serializer.serialize(task)
@@ -221,7 +230,7 @@ private[spark] object Task {
* @return (taskFiles, taskJars, taskBytes)
*/
def deserializeWithDependencies(serializedTask: ByteBuffer)
: (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
: (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = {

val in = new ByteBufferInputStream(serializedTask)
val dataIn = new DataInputStream(in)
Expand All @@ -240,8 +249,13 @@ private[spark] object Task {
taskJars(dataIn.readUTF()) = dataIn.readLong()
}

val propLength = dataIn.readInt()
val propBytes = new Array[Byte](propLength)
dataIn.readFully(propBytes, 0, propLength)
val taskProps = Utils.deserialize[Properties](propBytes)

// Create a sub-buffer for the rest of the data, which is the serialized Task object
val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
(taskFiles, taskJars, subBuffer)
(taskFiles, taskJars, taskProps, subBuffer)
}
}
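For reference, a self-contained sketch of the length-prefixed properties segment that serializeWithDependencies now writes and deserializeWithDependencies reads back. Plain JDK Java serialization stands in for Spark's Utils.serialize/Utils.deserialize helpers (which today delegate to Java serialization anyway), and the object name and property key are invented for the example:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.Properties

object PropsWireFormatSketch {

  // Java-serialize the Properties and length-prefix the bytes, mirroring the
  // writeInt/write calls added to Task.serializeWithDependencies.
  def writeProps(dataOut: DataOutputStream, props: Properties): Unit = {
    val buf = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buf)
    oos.writeObject(props)
    oos.close()
    val propBytes = buf.toByteArray
    dataOut.writeInt(propBytes.length)
    dataOut.write(propBytes)
  }

  // Read the length, then exactly that many bytes, then deserialize; the same
  // readInt/readFully pattern used in Task.deserializeWithDependencies.
  def readProps(dataIn: DataInputStream): Properties = {
    val propLength = dataIn.readInt()
    val propBytes = new Array[Byte](propLength)
    dataIn.readFully(propBytes, 0, propLength)
    val ois = new ObjectInputStream(new ByteArrayInputStream(propBytes))
    ois.readObject().asInstanceOf[Properties]
  }

  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.setProperty("myapp.user", "alice") // illustrative user-defined property
    val out = new ByteArrayOutputStream()
    writeProps(new DataOutputStream(out), props)
    val in = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
    assert(readProps(in).getProperty("myapp.user") == "alice")
  }
}
```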
5 changes: 3 additions & 2 deletions core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.util.Properties
import java.util.concurrent.Semaphore
import javax.annotation.concurrent.GuardedBy

@@ -292,7 +293,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
// Now we're on the executors.
// Deserialize the task and assert that its accumulators are zero'ed out.
val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
val taskDeser = serInstance.deserialize[DummyTask](
taskBytes, Thread.currentThread.getContextClassLoader)
// Assert that executors see only zeros
@@ -403,6 +404,6 @@ private class SaveInfoListener extends SparkListener {
private[spark] class DummyTask(
val internalAccums: Seq[Accumulator[_]],
val externalAccums: Seq[Accumulator[_]])
extends Task[Int](0, 0, 0, internalAccums) {
extends Task[Int](0, 0, 0, internalAccums, new Properties) {
override def runTask(c: TaskContext): Int = 1
}
7 changes: 4 additions & 3 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.util.Properties
import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}

import org.scalatest.Matchers
@@ -335,15 +336,15 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC

// first attempt -- its successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val data1 = (1 to 10).map { x => x -> x}

// second attempt -- also successful. We'll write out different data,
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val data2 = (11 to 20).map { x => x -> x}

@@ -372,7 +373,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
}

val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem,
InternalAccumulator.create(sc)))
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -17,6 +17,8 @@

package org.apache.spark.memory

import java.util.Properties

import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}

/**
@@ -31,6 +33,7 @@ object MemoryTestingUtils {
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
localProperties = new Properties,
metricsSystem = env.metricsSystem)
}
}
core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -17,12 +17,14 @@

package org.apache.spark.scheduler

import java.util.Properties

import org.apache.spark.TaskContext

class FakeTask(
stageId: Int,
prefLocs: Seq[TaskLocation] = Nil)
extends Task[Int](stageId, 0, 0, Seq.empty) {
extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -18,14 +18,15 @@
package org.apache.spark.scheduler

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import java.util.Properties

import org.apache.spark.TaskContext

/**
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {

override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()