Skip to content

Commit fb5a296

Browse files
committed
Addressed PR comments
1 parent a78130d commit fb5a296

File tree

10 files changed

+318
-248
lines changed

10 files changed

+318
-248
lines changed

streaming/src/main/scala/org/apache/spark/streaming/State.scala

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,31 +29,23 @@ import org.apache.spark.annotation.Experimental
2929
*
3030
* Scala example of using `State`:
3131
* {{{
32-
* def trackStateFunc(key: String, data: Option[Int], wrappedState: State[Int]): Option[Int] = {
33-
*
32+
* // A tracking function that maintains an integer state and returns a String
33+
* def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
3434
* // Check if state exists
3535
* if (state.exists) {
36-
*
37-
* val existingState = wrappedState.get // Get the existing state
38-
*
39-
* val shouldRemove = ... // Decide whether to remove the state
40-
*
36+
* val existingState = state.get // Get the existing state
37+
* val shouldRemove = ... // Decide whether to remove the state
4138
* if (shouldRemove) {
42-
*
43-
* wrappedState.remove() // Remove the state
44-
*
39+
* state.remove() // Remove the state
4540
* } else {
46-
*
4741
* val newState = ...
48-
* wrappedState(newState) // Set the new state
49-
*
42+
* state.update(newState) // Set the new state
5043
* }
5144
* } else {
52-
*
5345
* val initialState = ...
5446
* state.update(initialState) // Set the initial state
55-
*
5647
* }
48+
* ... // return something
5749
* }
5850
*
5951
* }}}
@@ -98,7 +90,7 @@ sealed abstract class State[S] {
9890

9991
/**
10092
* Whether the state is timing out and going to be removed by the system after the current batch.
101-
* This timeou can occur if timeout duration has been specified in the
93+
* This timeout can occur if timeout duration has been specified in the
10294
* [[org.apache.spark.streaming.StateSpec StateSpec]] and the key has not received any new data
10395
* for that timeout duration.
10496
*/
@@ -114,16 +106,11 @@ sealed abstract class State[S] {
114106
}
115107
}
116108

117-
private[streaming]
118-
object State {
119-
implicit def toOption[S](state: State[S]): Option[S] = state.getOption()
120-
}
121-
122109
/** Internal implementation of the [[State]] interface */
123110
private[streaming] class StateImpl[S] extends State[S] {
124111

125112
private var state: S = null.asInstanceOf[S]
126-
private var defined: Boolean = true
113+
private var defined: Boolean = false
127114
private var timingOut: Boolean = false
128115
private var updated: Boolean = false
129116
private var removed: Boolean = false
@@ -134,13 +121,18 @@ private[streaming] class StateImpl[S] extends State[S] {
134121
}
135122

136123
override def get(): S = {
137-
state
124+
if (defined) {
125+
state
126+
} else {
127+
throw new NoSuchElementException("State is not set")
128+
}
138129
}
139130

140131
override def update(newState: S): Unit = {
141132
require(!removed, "Cannot update the state after it has been removed")
142133
require(!timingOut, "Cannot update the state that is timing out")
143134
state = newState
135+
defined = true
144136
updated = true
145137
}
146138

@@ -151,6 +143,8 @@ private[streaming] class StateImpl[S] extends State[S] {
151143
override def remove(): Unit = {
152144
require(!timingOut, "Cannot remove the state that is timing out")
153145
require(!removed, "Cannot remove the state that has already been removed")
146+
defined = false
147+
updated = false
154148
removed = true
155149
}
156150

@@ -167,7 +161,7 @@ private[streaming] class StateImpl[S] extends State[S] {
167161
}
168162

169163
/**
170-
* Internal method to update the state data and reset internal flags in `this`.
164+
* Update the internal data and flags in `this` to the given state option.
171165
* This method allows `this` object to be reused across many state records.
172166
*/
173167
def wrap(optionalState: Option[S]): Unit = {
@@ -186,7 +180,7 @@ private[streaming] class StateImpl[S] extends State[S] {
186180
}
187181

188182
/**
189-
* Internal method to update the state data and reset internal flags in `this`.
183+
* Update the internal data and flags in `this` to the given state that is going to be timed out.
190184
* This method allows `this` object to be reused across many state records.
191185
*/
192186
def wrapTiminoutState(newState: S): Unit = {

streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
2222
import org.apache.spark.annotation.Experimental
2323
import org.apache.spark.api.java.JavaPairRDD
2424
import org.apache.spark.rdd.RDD
25+
import org.apache.spark.util.ClosureCleaner
2526
import org.apache.spark.{HashPartitioner, Partitioner}
2627

2728

@@ -37,28 +38,33 @@ import org.apache.spark.{HashPartitioner, Partitioner}
3738
*
3839
* Example in Scala:
3940
* {{{
40-
* val spec = StateSpec(trackingFunction).numPartitions(10)
41+
* def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
42+
* ...
43+
* }
44+
*
45+
* val spec = StateSpec.function(trackingFunction).numPartitions(10)
4146
*
4247
* val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
4348
* }}}
4449
*
4550
* Example in Java:
4651
* {{{
47-
* StateStateSpec[StateType, EmittedDataType] spec =
48-
* StateStateSpec.create[StateType, EmittedDataType](trackingFunction).numPartition(10);
52+
* StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
53+
* StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
54+
* .numPartitions(10);
4955
*
5056
* JavaDStream[EmittedDataType] emittedRecordDStream =
5157
* javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
5258
* }}}
5359
*/
5460
@Experimental
55-
sealed abstract class StateSpec[K, V, S, T] extends Serializable {
61+
sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
5662

5763
/** Set the RDD containing the initial states that will be used by `trackStateByKey`*/
58-
def initialState(rdd: RDD[(K, S)]): this.type
64+
def initialState(rdd: RDD[(KeyType, StateType)]): this.type
5965

6066
/** Set the RDD containing the initial states that will be used by `trackStateByKey`*/
61-
def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type
67+
def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
6268

6369
/**
6470
* Set the number of partitions by which the state RDDs generated by `trackStateByKey`
@@ -93,15 +99,20 @@ sealed abstract class StateSpec[K, V, S, T] extends Serializable {
9399
*
94100
* Example in Scala:
95101
* {{{
96-
* val spec = StateSpec(trackingFunction).numPartitions(10)
102+
* def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
103+
* ...
104+
* }
105+
*
106+
* val spec = StateSpec.function(trackingFunction).numPartitions(10)
97107
*
98108
* val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
99109
* }}}
100110
*
101111
* Example in Java:
102112
* {{{
103-
* StateStateSpec[StateType, EmittedDataType] spec =
104-
* StateStateSpec.create[StateType, EmittedDataType](trackingFunction).numPartition(10);
113+
* StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
114+
* StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
115+
* .numPartitions(10);
105116
*
106117
* JavaDStream[EmittedDataType] emittedRecordDStream =
107118
* javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
@@ -115,16 +126,17 @@ object StateSpec {
115126
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
116127
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
117128
* @param trackingFunction The function applied on every data item to manage the associated state
118-
* and generate the emitted data and
129+
* and generate the emitted data
119130
* @tparam KeyType Class of the keys
120131
* @tparam ValueType Class of the values
121132
* @tparam StateType Class of the state data
122133
* @tparam EmittedType Class of the emitted data
123134
*/
124-
def apply[KeyType, ValueType, StateType, EmittedType](
125-
trackingFunction: (KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
135+
def function[KeyType, ValueType, StateType, EmittedType](
136+
trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
126137
): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
127-
new StateSpecImpl[KeyType, ValueType, StateType, EmittedType](trackingFunction)
138+
ClosureCleaner.clean(trackingFunction, checkSerializable = true)
139+
new StateSpecImpl(trackingFunction)
128140
}
129141

130142
/**
@@ -133,24 +145,28 @@ object StateSpec {
133145
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
134146
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
135147
* @param trackingFunction The function applied on every data item to manage the associated state
136-
* and generate the emitted data and
137-
* @tparam KeyType Class of the keys
148+
* and generate the emitted data
138149
* @tparam ValueType Class of the values
139150
* @tparam StateType Class of the state data
140151
* @tparam EmittedType Class of the emitted data
141152
*/
142-
def create[KeyType, ValueType, StateType, EmittedType](
143-
trackingFunction: (KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
144-
): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
145-
apply(trackingFunction)
153+
def function[ValueType, StateType, EmittedType](
154+
trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
155+
): StateSpec[Any, ValueType, StateType, EmittedType] = {
156+
ClosureCleaner.clean(trackingFunction, checkSerializable = true)
157+
val wrappedFunction =
158+
(time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
159+
Some(trackingFunction(value, state))
160+
}
161+
new StateSpecImpl[Any, ValueType, StateType, EmittedType](wrappedFunction)
146162
}
147163
}
148164

149165

150166
/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. */
151167
private[streaming]
152168
case class StateSpecImpl[K, V, S, T](
153-
function: (K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] {
169+
function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] {
154170

155171
require(function != null)
156172

@@ -186,7 +202,7 @@ case class StateSpecImpl[K, V, S, T](
186202

187203
// ================= Private Methods =================
188204

189-
private[streaming] def getFunction(): (K, Option[V], State[S]) => Option[T] = function
205+
private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => Option[T] = function
190206

191207
private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD)
192208

streaming/src/main/scala/org/apache/spark/streaming/dstream/EmittedRecordsDStream.scala

Lines changed: 0 additions & 115 deletions
This file was deleted.

0 commit comments

Comments
 (0)