Skip to content

Commit e4e46b2

Browse files
committed
[SPARK-11681][STREAMING] Correctly update state timestamp even when state is not updated
Bug: The timestamp is not updated if there is data for a key but the corresponding state is not updated. This is wrong, because timeout is defined as "no data for a while", not "no state update for a while". Fix: Update the timestamp whenever there is data and a timeout is specified; otherwise there is no need. Also refactored the code for better testability and added unit tests. Author: Tathagata Das <tathagata.das1565@gmail.com> Closes #9648 from tdas/SPARK-11681.
1 parent 7786f9c commit e4e46b2

File tree

2 files changed

+192
-49
lines changed

2 files changed

+192
-49
lines changed

streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala

Lines changed: 61 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,51 @@ import org.apache.spark._
3232
* Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
3333
* sequence of records returned by the tracking function of `trackStateByKey`.
3434
*/
35-
private[streaming] case class TrackStateRDDRecord[K, S, T](
36-
var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
35+
private[streaming] case class TrackStateRDDRecord[K, S, E](
36+
var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
37+
38+
private[streaming] object TrackStateRDDRecord {
39+
def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
40+
prevRecord: Option[TrackStateRDDRecord[K, S, E]],
41+
dataIterator: Iterator[(K, V)],
42+
updateFunction: (Time, K, Option[V], State[S]) => Option[E],
43+
batchTime: Time,
44+
timeoutThresholdTime: Option[Long],
45+
removeTimedoutData: Boolean
46+
): TrackStateRDDRecord[K, S, E] = {
47+
// Create a new state map by cloning the previous one (if it exists) or by creating an empty one
48+
val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
49+
50+
val emittedRecords = new ArrayBuffer[E]
51+
val wrappedState = new StateImpl[S]()
52+
53+
// Call the tracking function on each record in the data iterator, and accordingly
54+
// update the states touched, and collect the data returned by the tracking function
55+
dataIterator.foreach { case (key, value) =>
56+
wrappedState.wrap(newStateMap.get(key))
57+
val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
58+
if (wrappedState.isRemoved) {
59+
newStateMap.remove(key)
60+
} else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
61+
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
62+
}
63+
emittedRecords ++= emittedRecord
64+
}
65+
66+
// Get the timed out state records, call the tracking function on each and collect the
67+
// data returned
68+
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
69+
newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
70+
wrappedState.wrapTiminoutState(state)
71+
val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
72+
emittedRecords ++= emittedRecord
73+
newStateMap.remove(key)
74+
}
75+
}
76+
77+
TrackStateRDDRecord(newStateMap, emittedRecords)
78+
}
79+
}
3780

3881
/**
3982
* Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
@@ -72,16 +115,16 @@ private[streaming] class TrackStateRDDPartition(
72115
* @param batchTime The time of the batch to which this RDD belongs. Used to update the timestamp of the states that receive data
73116
* @param timeoutThresholdTime The threshold time used to decide which keys have timed out
74117
*/
75-
private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
76-
private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
118+
private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
119+
private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
77120
private var partitionedDataRDD: RDD[(K, V)],
78-
trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
121+
trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
79122
batchTime: Time,
80123
timeoutThresholdTime: Option[Long]
81-
) extends RDD[TrackStateRDDRecord[K, S, T]](
124+
) extends RDD[TrackStateRDDRecord[K, S, E]](
82125
partitionedDataRDD.sparkContext,
83126
List(
84-
new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
127+
new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
85128
new OneToOneDependency(partitionedDataRDD))
86129
) {
87130

@@ -98,50 +141,24 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
98141
}
99142

100143
override def compute(
101-
partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
144+
partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
102145

103146
val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
104147
val prevStateRDDIterator = prevStateRDD.iterator(
105148
stateRDDPartition.previousSessionRDDPartition, context)
106149
val dataIterator = partitionedDataRDD.iterator(
107150
stateRDDPartition.partitionedDataRDDPartition, context)
108151

109-
// Create a new state map by cloning the previous one (if it exists) or by creating an empty one
110-
val newStateMap = if (prevStateRDDIterator.hasNext) {
111-
prevStateRDDIterator.next().stateMap.copy()
112-
} else {
113-
new EmptyStateMap[K, S]()
114-
}
115-
116-
val emittedRecords = new ArrayBuffer[T]
117-
val wrappedState = new StateImpl[S]()
118-
119-
// Call the tracking function on each record in the data RDD partition, and accordingly
120-
// update the states touched, and the data returned by the tracking function.
121-
dataIterator.foreach { case (key, value) =>
122-
wrappedState.wrap(newStateMap.get(key))
123-
val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
124-
if (wrappedState.isRemoved) {
125-
newStateMap.remove(key)
126-
} else if (wrappedState.isUpdated) {
127-
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
128-
}
129-
emittedRecords ++= emittedRecord
130-
}
131-
132-
// If the RDD is expected to be doing a full scan of all the data in the StateMap,
133-
// then use this opportunity to filter out those keys that have timed out.
134-
// For each of them call the tracking function.
135-
if (doFullScan && timeoutThresholdTime.isDefined) {
136-
newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
137-
wrappedState.wrapTiminoutState(state)
138-
val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
139-
emittedRecords ++= emittedRecord
140-
newStateMap.remove(key)
141-
}
142-
}
143-
144-
Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
152+
val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
153+
val newRecord = TrackStateRDDRecord.updateRecordWithData(
154+
prevRecord,
155+
dataIterator,
156+
trackingFunction,
157+
batchTime,
158+
timeoutThresholdTime,
159+
removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
160+
)
161+
Iterator(newRecord)
145162
}
146163

147164
override protected def getPartitions: Array[Partition] = {

streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala

Lines changed: 131 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
2323
import org.scalatest.BeforeAndAfterAll
2424

2525
import org.apache.spark.rdd.RDD
26+
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
2627
import org.apache.spark.streaming.{Time, State}
2728
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
2829

@@ -52,6 +53,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
5253
assert(rdd.partitioner === Some(partitioner))
5354
}
5455

56+
test("updating state and generating emitted data in TrackStateRecord") {
57+
58+
val initialTime = 1000L
59+
val updatedTime = 2000L
60+
val thresholdTime = 1500L
61+
@volatile var functionCalled = false
62+
63+
/**
64+
* Assert that applying given data on a prior record generates correct updated record, with
65+
* correct state map and emitted data
66+
*/
67+
def assertRecordUpdate(
68+
initStates: Iterable[Int],
69+
data: Iterable[String],
70+
expectedStates: Iterable[(Int, Long)],
71+
timeoutThreshold: Option[Long] = None,
72+
removeTimedoutData: Boolean = false,
73+
expectedOutput: Iterable[Int] = None,
74+
expectedTimingOutStates: Iterable[Int] = None,
75+
expectedRemovedStates: Iterable[Int] = None
76+
): Unit = {
77+
val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
78+
initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
79+
functionCalled = false
80+
val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
81+
val dataIterator = data.map { v => ("key", v) }.iterator
82+
val removedStates = new ArrayBuffer[Int]
83+
val timingOutStates = new ArrayBuffer[Int]
84+
/**
85+
* Tracking function that updates/removes state based on instructions in the data, and
86+
* return state (when instructed or when state is timing out).
87+
*/
88+
def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
89+
functionCalled = true
90+
91+
assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
92+
93+
data match {
94+
case Some("noop") =>
95+
None
96+
case Some("get-state") =>
97+
Some(state.getOption().getOrElse(-1))
98+
case Some("update-state") =>
99+
if (state.exists) state.update(state.get + 1) else state.update(0)
100+
None
101+
case Some("remove-state") =>
102+
removedStates += state.get()
103+
state.remove()
104+
None
105+
case None =>
106+
assert(state.isTimingOut() === true, "State is not timing out when data = None")
107+
timingOutStates += state.get()
108+
None
109+
case _ =>
110+
fail("Unexpected test data")
111+
}
112+
}
113+
114+
val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
115+
Some(record), dataIterator, testFunc,
116+
Time(updatedTime), timeoutThreshold, removeTimedoutData)
117+
118+
val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
119+
assert(updatedStateData.toSet === expectedStates.toSet,
120+
"states do not match after updating the TrackStateRecord")
121+
122+
assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
123+
"emitted data do not match after updating the TrackStateRecord")
124+
125+
assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
126+
"match those that were expected to do so while updating the TrackStateRecord")
127+
128+
assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
129+
"match those that were expected to do so while updating the TrackStateRecord")
130+
131+
}
132+
133+
// No data, no state should be changed, function should not be called,
134+
assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
135+
assert(functionCalled === false)
136+
assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
137+
assert(functionCalled === false)
138+
139+
// Data present, function should be called irrespective of whether state exists
140+
assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
141+
expectedStates = Seq((0, initialTime)))
142+
assert(functionCalled === true)
143+
assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
144+
assert(functionCalled === true)
145+
146+
// Function called with right state data
147+
assertRecordUpdate(initStates = None, data = Seq("get-state"),
148+
expectedStates = None, expectedOutput = Seq(-1))
149+
assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
150+
expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
151+
152+
// Update state and timestamp, when timeout not present
153+
assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
154+
expectedStates = Seq((0, updatedTime)))
155+
assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
156+
expectedStates = Seq((1, updatedTime)))
157+
158+
// Remove state
159+
assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
160+
expectedStates = Nil, expectedRemovedStates = Seq(345))
161+
162+
// State strictly older than timeout threshold should be timed out
163+
assertRecordUpdate(initStates = Seq(123), data = Nil,
164+
timeoutThreshold = Some(initialTime), removeTimedoutData = true,
165+
expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
166+
167+
assertRecordUpdate(initStates = Seq(123), data = Nil,
168+
timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
169+
expectedStates = Nil, expectedTimingOutStates = Seq(123))
170+
171+
// State should not be timed out after it has received data
172+
assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
173+
timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
174+
expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
175+
assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
176+
timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
177+
expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
178+
179+
}
180+
55181
test("states generated by TrackStateRDD") {
56182
val initStates = Seq(("k1", 0), ("k2", 0))
57183
val initTime = 123
@@ -148,9 +274,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
148274
val rdd7 = testStateUpdates( // should remove k2's state
149275
rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
150276

151-
val rdd8 = testStateUpdates(
152-
rdd7, Seq(("k3", 2)), Set() //
153-
)
277+
val rdd8 = testStateUpdates( // should remove k3's state
278+
rdd7, Seq(("k3", 2)), Set())
154279
}
155280

156281
/** Assert whether the `trackStateByKey` operation generates expected results */
@@ -176,7 +301,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
176301

177302
// Persist to make sure that it gets computed only once and we can track precisely how many
178303
// state keys the computing touched
179-
newStateRDD.persist()
304+
newStateRDD.persist().count()
180305
assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
181306
newStateRDD
182307
}
@@ -188,7 +313,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
188313
expectedEmittedRecords: Set[T]): Unit = {
189314
val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
190315
val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
191-
assert(states === expectedStates, "states after track state operation were not as expected")
316+
assert(states === expectedStates,
317+
"states after track state operation were not as expected")
192318
assert(emittedRecords === expectedEmittedRecords,
193319
"emitted records after track state operation were not as expected")
194320
}

0 commit comments

Comments
 (0)