@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
23
23
import org .scalatest .BeforeAndAfterAll
24
24
25
25
import org .apache .spark .rdd .RDD
26
+ import org .apache .spark .streaming .util .OpenHashMapBasedStateMap
26
27
import org .apache .spark .streaming .{Time , State }
27
28
import org .apache .spark .{HashPartitioner , SparkConf , SparkContext , SparkFunSuite }
28
29
@@ -52,6 +53,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
52
53
assert(rdd.partitioner === Some (partitioner))
53
54
}
54
55
56
+ test(" updating state and generating emitted data in TrackStateRecord" ) {
57
+
58
+ val initialTime = 1000L
59
+ val updatedTime = 2000L
60
+ val thresholdTime = 1500L
61
+ @ volatile var functionCalled = false
62
+
63
+ /**
64
+ * Assert that applying given data on a prior record generates correct updated record, with
65
+ * correct state map and emitted data
66
+ */
67
+ def assertRecordUpdate (
68
+ initStates : Iterable [Int ],
69
+ data : Iterable [String ],
70
+ expectedStates : Iterable [(Int , Long )],
71
+ timeoutThreshold : Option [Long ] = None ,
72
+ removeTimedoutData : Boolean = false ,
73
+ expectedOutput : Iterable [Int ] = None ,
74
+ expectedTimingOutStates : Iterable [Int ] = None ,
75
+ expectedRemovedStates : Iterable [Int ] = None
76
+ ): Unit = {
77
+ val initialStateMap = new OpenHashMapBasedStateMap [String , Int ]()
78
+ initStates.foreach { s => initialStateMap.put(" key" , s, initialTime) }
79
+ functionCalled = false
80
+ val record = TrackStateRDDRecord [String , Int , Int ](initialStateMap, Seq .empty)
81
+ val dataIterator = data.map { v => (" key" , v) }.iterator
82
+ val removedStates = new ArrayBuffer [Int ]
83
+ val timingOutStates = new ArrayBuffer [Int ]
84
+ /**
85
+ * Tracking function that updates/removes state based on instructions in the data, and
86
+ * return state (when instructed or when state is timing out).
87
+ */
88
+ def testFunc (t : Time , key : String , data : Option [String ], state : State [Int ]): Option [Int ] = {
89
+ functionCalled = true
90
+
91
+ assert(t.milliseconds === updatedTime, " tracking func called with wrong time" )
92
+
93
+ data match {
94
+ case Some (" noop" ) =>
95
+ None
96
+ case Some (" get-state" ) =>
97
+ Some (state.getOption().getOrElse(- 1 ))
98
+ case Some (" update-state" ) =>
99
+ if (state.exists) state.update(state.get + 1 ) else state.update(0 )
100
+ None
101
+ case Some (" remove-state" ) =>
102
+ removedStates += state.get()
103
+ state.remove()
104
+ None
105
+ case None =>
106
+ assert(state.isTimingOut() === true , " State is not timing out when data = None" )
107
+ timingOutStates += state.get()
108
+ None
109
+ case _ =>
110
+ fail(" Unexpected test data" )
111
+ }
112
+ }
113
+
114
+ val updatedRecord = TrackStateRDDRecord .updateRecordWithData[String , String , Int , Int ](
115
+ Some (record), dataIterator, testFunc,
116
+ Time (updatedTime), timeoutThreshold, removeTimedoutData)
117
+
118
+ val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
119
+ assert(updatedStateData.toSet === expectedStates.toSet,
120
+ " states do not match after updating the TrackStateRecord" )
121
+
122
+ assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
123
+ " emitted data do not match after updating the TrackStateRecord" )
124
+
125
+ assert(timingOutStates.toSet === expectedTimingOutStates.toSet, " timing out states do not " +
126
+ " match those that were expected to do so while updating the TrackStateRecord" )
127
+
128
+ assert(removedStates.toSet === expectedRemovedStates.toSet, " removed states do not " +
129
+ " match those that were expected to do so while updating the TrackStateRecord" )
130
+
131
+ }
132
+
133
+ // No data, no state should be changed, function should not be called,
134
+ assertRecordUpdate(initStates = Nil , data = None , expectedStates = Nil )
135
+ assert(functionCalled === false )
136
+ assertRecordUpdate(initStates = Seq (0 ), data = None , expectedStates = Seq ((0 , initialTime)))
137
+ assert(functionCalled === false )
138
+
139
+ // Data present, function should be called irrespective of whether state exists
140
+ assertRecordUpdate(initStates = Seq (0 ), data = Seq (" noop" ),
141
+ expectedStates = Seq ((0 , initialTime)))
142
+ assert(functionCalled === true )
143
+ assertRecordUpdate(initStates = None , data = Some (" noop" ), expectedStates = None )
144
+ assert(functionCalled === true )
145
+
146
+ // Function called with right state data
147
+ assertRecordUpdate(initStates = None , data = Seq (" get-state" ),
148
+ expectedStates = None , expectedOutput = Seq (- 1 ))
149
+ assertRecordUpdate(initStates = Seq (123 ), data = Seq (" get-state" ),
150
+ expectedStates = Seq ((123 , initialTime)), expectedOutput = Seq (123 ))
151
+
152
+ // Update state and timestamp, when timeout not present
153
+ assertRecordUpdate(initStates = Nil , data = Seq (" update-state" ),
154
+ expectedStates = Seq ((0 , updatedTime)))
155
+ assertRecordUpdate(initStates = Seq (0 ), data = Seq (" update-state" ),
156
+ expectedStates = Seq ((1 , updatedTime)))
157
+
158
+ // Remove state
159
+ assertRecordUpdate(initStates = Seq (345 ), data = Seq (" remove-state" ),
160
+ expectedStates = Nil , expectedRemovedStates = Seq (345 ))
161
+
162
+ // State strictly older than timeout threshold should be timed out
163
+ assertRecordUpdate(initStates = Seq (123 ), data = Nil ,
164
+ timeoutThreshold = Some (initialTime), removeTimedoutData = true ,
165
+ expectedStates = Seq ((123 , initialTime)), expectedTimingOutStates = Nil )
166
+
167
+ assertRecordUpdate(initStates = Seq (123 ), data = Nil ,
168
+ timeoutThreshold = Some (initialTime + 1 ), removeTimedoutData = true ,
169
+ expectedStates = Nil , expectedTimingOutStates = Seq (123 ))
170
+
171
+ // State should not be timed out after it has received data
172
+ assertRecordUpdate(initStates = Seq (123 ), data = Seq (" noop" ),
173
+ timeoutThreshold = Some (initialTime + 1 ), removeTimedoutData = true ,
174
+ expectedStates = Seq ((123 , updatedTime)), expectedTimingOutStates = Nil )
175
+ assertRecordUpdate(initStates = Seq (123 ), data = Seq (" remove-state" ),
176
+ timeoutThreshold = Some (initialTime + 1 ), removeTimedoutData = true ,
177
+ expectedStates = Nil , expectedTimingOutStates = Nil , expectedRemovedStates = Seq (123 ))
178
+
179
+ }
180
+
55
181
test(" states generated by TrackStateRDD" ) {
56
182
val initStates = Seq ((" k1" , 0 ), (" k2" , 0 ))
57
183
val initTime = 123
@@ -148,9 +274,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
148
274
val rdd7 = testStateUpdates( // should remove k2's state
149
275
rdd6, Seq ((" k2" , 2 ), (" k0" , 2 ), (" k3" , 1 )), Set ((" k3" , 0 , updateTime)))
150
276
151
- val rdd8 = testStateUpdates(
152
- rdd7, Seq ((" k3" , 2 )), Set () //
153
- )
277
+ val rdd8 = testStateUpdates( // should remove k3's state
278
+ rdd7, Seq ((" k3" , 2 )), Set ())
154
279
}
155
280
156
281
/** Assert whether the `trackStateByKey` operation generates expected results */
@@ -176,7 +301,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
176
301
177
302
// Persist to make sure that it gets computed only once and we can track precisely how many
178
303
// state keys the computing touched
179
- newStateRDD.persist()
304
+ newStateRDD.persist().count()
180
305
assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
181
306
newStateRDD
182
307
}
@@ -188,7 +313,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
188
313
expectedEmittedRecords : Set [T ]): Unit = {
189
314
val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
190
315
val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
191
- assert(states === expectedStates, " states after track state operation were not as expected" )
316
+ assert(states === expectedStates,
317
+ " states after track state operation were not as expected" )
192
318
assert(emittedRecords === expectedEmittedRecords,
193
319
" emitted records after track state operation were not as expected" )
194
320
}
0 commit comments