Skip to content

Commit d6cd20f

Browse files
committed
moving getListState/getMapState methods
1 parent dd07c52 commit d6cd20f

File tree

1 file changed

+74
-74
lines changed

1 file changed

+74
-74
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala

Lines changed: 74 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,80 @@ class StatefulProcessorHandleImpl(
160160
valueStateWithTTL
161161
}
162162

163+
override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
164+
verifyStateVarOperations("get_list_state")
165+
incrementMetric("numListStateVars")
166+
val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder)
167+
stateVariables.add(new StateVariableInfo(stateName, ListState, false))
168+
columnFamilyMetadatas.add(resultState.columnFamilyMetadata)
169+
resultState
170+
}
171+
172+
/**
173+
* Function to create new or return existing list state variable of given type
174+
* with ttl. State values will not be returned past ttlDuration, and will be eventually removed
175+
* from the state store. Any values in listState which have expired after ttlDuration will not
176+
* returned on get() and will be eventually removed from the state.
177+
*
178+
* The user must ensure to call this function only within the `init()` method of the
179+
* StatefulProcessor.
180+
*
181+
* @param stateName - name of the state variable
182+
* @param valEncoder - SQL encoder for state variable
183+
* @param ttlConfig - the ttl configuration (time to live duration etc.)
184+
* @tparam T - type of state variable
185+
* @return - instance of ListState of type T that can be used to store state persistently
186+
*/
187+
override def getListState[T](
188+
stateName: String,
189+
valEncoder: Encoder[T],
190+
ttlConfig: TTLConfig): ListState[T] = {
191+
192+
verifyStateVarOperations("get_list_state")
193+
validateTTLConfig(ttlConfig, stateName)
194+
195+
assert(batchTimestampMs.isDefined)
196+
val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
197+
keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
198+
incrementMetric("numListStateWithTTLVars")
199+
ttlStates.add(listStateWithTTL)
200+
stateVariables.add(new StateVariableInfo(stateName, ListState, true))
201+
columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata)
202+
203+
listStateWithTTL
204+
}
205+
206+
override def getMapState[K, V](
207+
stateName: String,
208+
userKeyEnc: Encoder[K],
209+
valEncoder: Encoder[V]): MapState[K, V] = {
210+
verifyStateVarOperations("get_map_state")
211+
incrementMetric("numMapStateVars")
212+
val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder)
213+
stateVariables.add(new StateVariableInfo(stateName, MapState, false))
214+
columnFamilyMetadatas.add(resultState.columnFamilyMetadata)
215+
resultState
216+
}
217+
218+
override def getMapState[K, V](
219+
stateName: String,
220+
userKeyEnc: Encoder[K],
221+
valEncoder: Encoder[V],
222+
ttlConfig: TTLConfig): MapState[K, V] = {
223+
verifyStateVarOperations("get_map_state")
224+
validateTTLConfig(ttlConfig, stateName)
225+
226+
assert(batchTimestampMs.isDefined)
227+
val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc,
228+
valEncoder, ttlConfig, batchTimestampMs.get)
229+
incrementMetric("numMapStateWithTTLVars")
230+
ttlStates.add(mapStateWithTTL)
231+
stateVariables.add(new StateVariableInfo(stateName, MapState, true))
232+
columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata)
233+
234+
mapStateWithTTL
235+
}
236+
163237
override def getQueryInfo(): QueryInfo = currQueryInfo
164238

165239
private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder)
@@ -250,80 +324,6 @@ class StatefulProcessorHandleImpl(
250324
}
251325
}
252326

253-
override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
254-
verifyStateVarOperations("get_list_state")
255-
incrementMetric("numListStateVars")
256-
val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder)
257-
stateVariables.add(new StateVariableInfo(stateName, ListState, false))
258-
columnFamilyMetadatas.add(resultState.columnFamilyMetadata)
259-
resultState
260-
}
261-
262-
/**
263-
* Function to create new or return existing list state variable of given type
264-
* with ttl. State values will not be returned past ttlDuration, and will be eventually removed
265-
* from the state store. Any values in listState which have expired after ttlDuration will not
266-
* returned on get() and will be eventually removed from the state.
267-
*
268-
* The user must ensure to call this function only within the `init()` method of the
269-
* StatefulProcessor.
270-
*
271-
* @param stateName - name of the state variable
272-
* @param valEncoder - SQL encoder for state variable
273-
* @param ttlConfig - the ttl configuration (time to live duration etc.)
274-
* @tparam T - type of state variable
275-
* @return - instance of ListState of type T that can be used to store state persistently
276-
*/
277-
override def getListState[T](
278-
stateName: String,
279-
valEncoder: Encoder[T],
280-
ttlConfig: TTLConfig): ListState[T] = {
281-
282-
verifyStateVarOperations("get_list_state")
283-
validateTTLConfig(ttlConfig, stateName)
284-
285-
assert(batchTimestampMs.isDefined)
286-
val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
287-
keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
288-
incrementMetric("numListStateWithTTLVars")
289-
ttlStates.add(listStateWithTTL)
290-
stateVariables.add(new StateVariableInfo(stateName, ListState, true))
291-
columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata)
292-
293-
listStateWithTTL
294-
}
295-
296-
override def getMapState[K, V](
297-
stateName: String,
298-
userKeyEnc: Encoder[K],
299-
valEncoder: Encoder[V]): MapState[K, V] = {
300-
verifyStateVarOperations("get_map_state")
301-
incrementMetric("numMapStateVars")
302-
val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder)
303-
stateVariables.add(new StateVariableInfo(stateName, MapState, false))
304-
columnFamilyMetadatas.add(resultState.columnFamilyMetadata)
305-
resultState
306-
}
307-
308-
override def getMapState[K, V](
309-
stateName: String,
310-
userKeyEnc: Encoder[K],
311-
valEncoder: Encoder[V],
312-
ttlConfig: TTLConfig): MapState[K, V] = {
313-
verifyStateVarOperations("get_map_state")
314-
validateTTLConfig(ttlConfig, stateName)
315-
316-
assert(batchTimestampMs.isDefined)
317-
val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc,
318-
valEncoder, ttlConfig, batchTimestampMs.get)
319-
incrementMetric("numMapStateWithTTLVars")
320-
ttlStates.add(mapStateWithTTL)
321-
stateVariables.add(new StateVariableInfo(stateName, MapState, true))
322-
columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata)
323-
324-
mapStateWithTTL
325-
}
326-
327327
private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit = {
328328
val ttlDuration = ttlConfig.ttlDuration
329329
if (timeMode != TimeMode.ProcessingTime()) {

0 commit comments

Comments
 (0)