[SPARK-11681][Streaming] Correctly update state timestamp even when state is not updated #9648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change

@@ -32,8 +32,51 @@ import org.apache.spark._
  * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
  * sequence of records returned by the tracking function of `trackStateByKey`.
  */
-private[streaming] case class TrackStateRDDRecord[K, S, T](
-    var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
+private[streaming] case class TrackStateRDDRecord[K, S, E](
+    var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
 
+private[streaming] object TrackStateRDDRecord {
+  def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+      prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+      dataIterator: Iterator[(K, V)],
+      updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+      batchTime: Time,
+      timeoutThresholdTime: Option[Long],
+      removeTimedoutData: Boolean
+    ): TrackStateRDDRecord[K, S, E] = {
+    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+    val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }
+
+    val emittedRecords = new ArrayBuffer[E]
+    val wrappedState = new StateImpl[S]()
+
+    // Call the tracking function on each record in the data iterator, and accordingly
+    // update the states touched, and collect the data returned by the tracking function
+    dataIterator.foreach { case (key, value) =>
+      wrappedState.wrap(newStateMap.get(key))
+      val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+      if (wrappedState.isRemoved) {
+        newStateMap.remove(key)
+      } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
+        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+      }
+      emittedRecords ++= emittedRecord
+    }
+
+    // Get the timed out state records, call the tracking function on each and collect the
+    // data returned
+    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
+      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+        wrappedState.wrapTiminoutState(state)
+        val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
+        emittedRecords ++= emittedRecord
+        newStateMap.remove(key)
+      }
+    }
+
+    TrackStateRDDRecord(newStateMap, emittedRecords)
+  }
+}
 
 /**
  * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
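
Note: the `|| timeoutThresholdTime.isDefined` clause above is the heart of this fix. When a timeout is configured, any key that receives data in a batch now has its timestamp advanced to the current batch time even if the tracking function left the state value untouched, so an active key can no longer be timed out. A minimal, self-contained sketch of that rule (illustrative names only, not the StateMap internals):

    object TimestampRuleSketch {
      final case class Entry[S](state: S, updateTime: Long)

      // Mirrors the put() branch of updateRecordWithData: the timestamp moves forward
      // when the state was updated OR when a timeout is configured and data arrived.
      def onData[K, S](
          map: Map[K, Entry[S]],
          key: K,
          batchTimeMs: Long,
          stateUpdated: Boolean,
          timeoutConfigured: Boolean): Map[K, Entry[S]] = {
        map.get(key) match {
          case Some(entry) if stateUpdated || timeoutConfigured =>
            map.updated(key, entry.copy(updateTime = batchTimeMs)) // pre-fix: only if stateUpdated
          case _ => map
        }
      }

      def main(args: Array[String]): Unit = {
        val m = Map("k" -> Entry(42, 1000L))
        // The state value 42 is untouched, but the timestamp is refreshed to 2000,
        // so a later timeout threshold of, say, 1500 no longer expires the key.
        assert(onData(m, "k", 2000L, stateUpdated = false, timeoutConfigured = true)("k").updateTime == 2000L)
      }
    }
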
@@ -72,16 +115,16 @@ private[streaming] class TrackStateRDDPartition(
  * @param batchTime The time of the batch to which this RDD belongs to. Use to update
  * @param timeoutThresholdTime The time to indicate which keys are timeout
  */
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
+private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
     private var partitionedDataRDD: RDD[(K, V)],
-    trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
+    trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
     batchTime: Time,
     timeoutThresholdTime: Option[Long]
-  ) extends RDD[TrackStateRDDRecord[K, S, T]](
+  ) extends RDD[TrackStateRDDRecord[K, S, E]](
     partitionedDataRDD.sparkContext,
     List(
-      new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
+      new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
       new OneToOneDependency(partitionedDataRDD))
   ) {

@@ -98,50 +141,24 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
   }
 
   override def compute(
-      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
+      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
 
     val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
     val prevStateRDDIterator = prevStateRDD.iterator(
       stateRDDPartition.previousSessionRDDPartition, context)
     val dataIterator = partitionedDataRDD.iterator(
       stateRDDPartition.partitionedDataRDDPartition, context)
 
-    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
-    val newStateMap = if (prevStateRDDIterator.hasNext) {
-      prevStateRDDIterator.next().stateMap.copy()
-    } else {
-      new EmptyStateMap[K, S]()
-    }
-
-    val emittedRecords = new ArrayBuffer[T]
-    val wrappedState = new StateImpl[S]()
-
-    // Call the tracking function on each record in the data RDD partition, and accordingly
-    // update the states touched, and the data returned by the tracking function.
-    dataIterator.foreach { case (key, value) =>
-      wrappedState.wrap(newStateMap.get(key))
-      val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
-      if (wrappedState.isRemoved) {
-        newStateMap.remove(key)
-      } else if (wrappedState.isUpdated) {
-        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
-      }
-      emittedRecords ++= emittedRecord
-    }
-
-    // If the RDD is expected to be doing a full scan of all the data in the StateMap,
-    // then use this opportunity to filter out those keys that have timed out.
-    // For each of them call the tracking function.
-    if (doFullScan && timeoutThresholdTime.isDefined) {
-      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
-        wrappedState.wrapTiminoutState(state)
-        val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
-        emittedRecords ++= emittedRecord
-        newStateMap.remove(key)
-      }
-    }
-
-    Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
+    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
+    val newRecord = TrackStateRDDRecord.updateRecordWithData(
+      prevRecord,
+      dataIterator,
+      trackingFunction,
+      batchTime,
+      timeoutThresholdTime,
+      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
+    )
+    Iterator(newRecord)
   }
 
   override protected def getPartitions: Array[Partition] = {
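
For orientation between the two files: the `(Time, K, Option[V], State[S]) => Option[E]` function threaded through TrackStateRDD is the user-supplied tracking function of `trackStateByKey`. A hedged sketch of what such a function might look like (a hypothetical per-key running sum, not part of this patch), using only the State methods exercised above:

    import org.apache.spark.streaming.{State, Time}

    object RunningSumSketch {
      // Hypothetical user function matching the trackingFunction signature above.
      def runningSum(batchTime: Time, key: String, value: Option[Int], state: State[Long]): Option[(String, Long)] = {
        if (state.isTimingOut()) {
          None // value is None here; the timeout scan is about to remove this state
        } else {
          val sum = state.getOption().getOrElse(0L) + value.getOrElse(0).toLong
          state.update(sum) // with this patch, this also refreshes the key's timestamp
          Some((key, sum))
        }
      }
    }
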
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
 import org.apache.spark.streaming.{Time, State}
 import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
@@ -52,6 +53,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     assert(rdd.partitioner === Some(partitioner))
   }
 
+  test("updating state and generating emitted data in TrackStateRecord") {
+
+    val initialTime = 1000L
+    val updatedTime = 2000L
+    val thresholdTime = 1500L
+    @volatile var functionCalled = false
+
+    /**
+     * Assert that applying given data on a prior record generates correct updated record, with
+     * correct state map and emitted data
+     */
+    def assertRecordUpdate(
+        initStates: Iterable[Int],
+        data: Iterable[String],
+        expectedStates: Iterable[(Int, Long)],
+        timeoutThreshold: Option[Long] = None,
+        removeTimedoutData: Boolean = false,
+        expectedOutput: Iterable[Int] = None,
+        expectedTimingOutStates: Iterable[Int] = None,
+        expectedRemovedStates: Iterable[Int] = None
+      ): Unit = {
+      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
+      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
+      functionCalled = false
+      val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+      val dataIterator = data.map { v => ("key", v) }.iterator
+      val removedStates = new ArrayBuffer[Int]
+      val timingOutStates = new ArrayBuffer[Int]
+      /**
+       * Tracking function that updates/removes state based on instructions in the data, and
+       * returns state (when instructed or when state is timing out).
+       */
+      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
+        functionCalled = true
+
+        assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
+
+        data match {
+          case Some("noop") =>
+            None
+          case Some("get-state") =>
+            Some(state.getOption().getOrElse(-1))
+          case Some("update-state") =>
+            if (state.exists) state.update(state.get + 1) else state.update(0)
+            None
+          case Some("remove-state") =>
+            removedStates += state.get()
+            state.remove()
+            None
+          case None =>
+            assert(state.isTimingOut() === true, "State is not timing out when data = None")
+            timingOutStates += state.get()
+            None
+          case _ =>
+            fail("Unexpected test data")
+        }
+      }
+
+      val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+        Some(record), dataIterator, testFunc,
+        Time(updatedTime), timeoutThreshold, removeTimedoutData)
+
+      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
+      assert(updatedStateData.toSet === expectedStates.toSet,
+        "states do not match after updating the TrackStateRecord")
+
+      assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
+        "emitted data do not match after updating the TrackStateRecord")
+
+      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
+        "match those that were expected to do so while updating the TrackStateRecord")
+
+      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
+        "match those that were expected to do so while updating the TrackStateRecord")
+
+    }
+
+    // No data, no state should be changed, and the function should not be called
+    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
+    assert(functionCalled === false)
+    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === false)
+
+    // Data present, function should be called irrespective of whether state exists
+    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
+      expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === true)
+    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
+    assert(functionCalled === true)
+
+    // Function called with right state data
+    assertRecordUpdate(initStates = None, data = Seq("get-state"),
+      expectedStates = None, expectedOutput = Seq(-1))
+    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
+      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
+
+    // Update state and timestamp, when timeout not present
+    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
+      expectedStates = Seq((0, updatedTime)))
+    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
+      expectedStates = Seq((1, updatedTime)))
+
+    // Remove state
+    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
+      expectedStates = Nil, expectedRemovedStates = Seq(345))
+
+    // State strictly older than timeout threshold should be timed out
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
+      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
+
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Seq(123))
+
+    // State should not be timed out after it has received data
+    assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
+    assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
+
+  }
+
   test("states generated by TrackStateRDD") {
     val initStates = Seq(("k1", 0), ("k2", 0))
     val initTime = 123
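
The two threshold cases in the test above pin down the timeout predicate: a state times out only when its timestamp is strictly older than the threshold, which is why `Some(initialTime)` leaves the state alive while `Some(initialTime + 1)` expires it. A tiny script-style illustration of that comparison (not the OpenHashMapBasedStateMap implementation):

    final case class TimedEntry(key: String, state: Int, updateTime: Long)

    // getByTime-style selection: only entries strictly older than the threshold time out.
    def timedOut(entries: Seq[TimedEntry], thresholdTime: Long): Seq[TimedEntry] =
      entries.filter(_.updateTime < thresholdTime)

    val entries = Seq(TimedEntry("key", 123, 1000L))
    assert(timedOut(entries, 1000L).isEmpty)  // threshold == timestamp: still alive
    assert(timedOut(entries, 1001L).nonEmpty) // strictly older than threshold: timed out
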
@@ -148,9 +274,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     val rdd7 = testStateUpdates( // should remove k2's state
       rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
 
-    val rdd8 = testStateUpdates(
-      rdd7, Seq(("k3", 2)), Set() //
-    )
+    val rdd8 = testStateUpdates( // should remove k3's state
+      rdd7, Seq(("k3", 2)), Set())
   }
 
   /** Assert whether the `trackStateByKey` operation generates expected results */
@@ -176,7 +301,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
 
     // Persist to make sure that it gets computed only once and we can track precisely how many
     // state keys the computing touched
-    newStateRDD.persist()
+    newStateRDD.persist().count()
     assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
     newStateRDD
   }
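
A note on the `.count()` added above: `persist()` is only a storage hint and computes nothing by itself, so without an action the RDD would be recomputed by each later assertion. A minimal sketch of the difference (generic Spark, hypothetical app name):

    import org.apache.spark.{SparkConf, SparkContext}

    object PersistCountSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("persist-sketch"))
        val evals = sc.accumulator(0) // counts how many times the map function runs
        val cached = sc.parallelize(1 to 100).map { i => evals += 1; i * 2 }.persist()
        cached.count()       // first action: materializes the RDD and fills the cache
        cached.collect()     // second action: served from the cache, map() does not re-run
        println(evals.value) // 100, not 200
        sc.stop()
      }
    }
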
@@ -188,7 +313,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
       expectedEmittedRecords: Set[T]): Unit = {
     val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
     val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
-    assert(states === expectedStates, "states after track state operation were not as expected")
+    assert(states === expectedStates,
+      "states after track state operation were not as expected")
     assert(emittedRecords === expectedEmittedRecords,
       "emitted records after track state operation were not as expected")
   }