Commit 04905dd

Okay this might be the way to go
1 parent ac3648e commit 04905dd

3 files changed: 101 additions, 15 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 9 additions & 6 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 import java.io._
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.util.control.NonFatal
 
@@ -58,14 +59,14 @@ private[sql] class RocksDBStateStoreProvider
 
   @volatile private var state: STATE = UPDATING
 
-  @volatile private var usedToCreateWriteStore: Boolean = false
+  private val usedToCreateWriteStore: AtomicBoolean = new AtomicBoolean(false)
 
   override def getReadStamp: Long = {
-    usedToCreateWriteStore = true
+    usedToCreateWriteStore.set(true)
     stamp
   }
 
-  override def usedForWriteStore: Boolean = usedToCreateWriteStore
+  override def usedForWriteStore: Boolean = usedToCreateWriteStore.get()
 
   /**
    * Validates the expected state, throws exception if state is not as expected.
@@ -132,7 +133,9 @@ private[sql] class RocksDBStateStoreProvider
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit] {
       _ =>
         try {
-          abort()
+          if (state == UPDATING) {
+            abort()
+          }
         } catch {
           case NonFatal(e) =>
             logWarning("Failed to abort state store", e)
@@ -337,7 +340,7 @@ private[sql] class RocksDBStateStoreProvider
     private var storedMetrics: Option[RocksDBMetrics] = None
 
     override def commit(): Long = synchronized {
-      if (usedToCreateWriteStore) {
+      if (usedToCreateWriteStore.get()) {
        return -1
      }
      validateState(List(UPDATING))
@@ -359,7 +362,7 @@ private[sql] class RocksDBStateStoreProvider
    }
 
    override def abort(): Unit = {
-      if (usedToCreateWriteStore) {
+      if (usedToCreateWriteStore.get()) {
        return
      }
      if (validateState(List(UPDATING, ABORTED)) != ABORTED) {
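The flag change above swaps a @volatile var for an AtomicBoolean, trading visibility-only semantics for a full atomic read-modify-write API, and the task-completion listener now aborts only while the store is still UPDATING, so a committed store is no longer aborted a second time. A minimal, self-contained sketch of the flag semantics (HandoffFlag and its methods are illustrative assumptions, not names from the patch):

import java.util.concurrent.atomic.AtomicBoolean

object HandoffFlag {
  private val handedOff = new AtomicBoolean(false)

  // What the patch does in getReadStamp: a plain write, visible to all threads.
  def mark(): Unit = handedOff.set(true)

  // What @volatile alone cannot provide: an atomic check-and-set, so at most
  // one racing thread observes false and wins the handoff.
  def tryClaim(): Boolean = handedOff.compareAndSet(false, true)

  def main(args: Array[String]): Unit = {
    println(tryClaim()) // true  (first claim wins)
    println(tryClaim()) // false (flag already set)
  }
}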

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 25 additions & 1 deletion
@@ -66,7 +66,7 @@ object StateStoreEncoding
  * not supported yet from the implementation. Note that some stateful operations would not work
  * on disabling prefixScan functionality.
  */
-trait ReadStateStore {
+trait ReadStateStore extends Logging {
 
   /** Unique identifier of the store */
   def id: StateStoreId
@@ -937,6 +937,30 @@ object StateStore extends Logging {
       stateSchemaBroadcast)
     storeProvider.getStore(version, stateStoreCkptId)
   }
+
+  def getWriteStore(
+      readStore: ReadStateStore,
+      storeProviderId: StateStoreProviderId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      version: Long,
+      stateStoreCkptId: Option[String],
+      stateSchemaBroadcast: Option[StateSchemaBroadcast],
+      useColumnFamilies: Boolean,
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): StateStore = {
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)
+    if (version < 0) {
+      throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
+    }
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)
+    val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
+      keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey,
+      stateSchemaBroadcast)
+    storeProvider.getWriteStore(readStore, version, stateStoreCkptId)
+  }
  // scalastyle:on
 
  private def getStateStoreProvider(
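The new getWriteStore mirrors StateStore.get (same RUN_ID_KEY setup, version check, and provider lookup) but hands the existing ReadStateStore to the provider so already-loaded state can be promoted to a writable store instead of being reopened from the checkpoint. A simplified sketch of that promotion pattern; every type and method below is a toy stand-in, not the real Spark API:

trait ToyReadStore { def version: Long }
trait ToyWriteStore extends ToyReadStore { def commit(): Long }

class ToyProvider {
  private class Store(val version: Long) extends ToyWriteStore {
    override def commit(): Long = version + 1
  }

  // Fresh open: load the version's files from the checkpoint and replay them.
  def getStore(version: Long): ToyWriteStore = new Store(version)

  // Promotion: the read store already holds this version in memory, so
  // validate it and upgrade it rather than loading everything again.
  def getWriteStore(readStore: ToyReadStore, version: Long): ToyWriteStore = {
    require(readStore.version == version,
      s"read store is at version ${readStore.version}, expected $version")
    new Store(version) // a real provider would reuse readStore's handles
  }
}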

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala

Lines changed: 67 additions & 8 deletions
@@ -27,6 +27,10 @@ import org.apache.spark.sql.internal.SessionState
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
 
+trait StateStoreRDDProvider {
+  def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore]
+}
+
 abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag](
     dataRDD: RDD[T],
     checkpointLocation: String,
@@ -82,7 +86,17 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     useColumnFamilies: Boolean = false,
     extraOptions: Map[String, String] = Map.empty)
   extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
-    sessionState, storeCoordinator, extraOptions) {
+    sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider {
+
+  // ThreadLocal to store state stores by partition ID
+  @transient private lazy val partitionStores =
+    new ThreadLocal[Map[Int, ReadStateStore]]() {
+      override def initialValue(): Map[Int, ReadStateStore] = Map.empty
+    }
+
+  override def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore] = {
+    Option(partitionStores.get()).flatMap(_.get(partitionId))
+  }
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
 
@@ -95,6 +109,8 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
       stateStoreCkptIds.map(_.apply(partition.index).head),
       stateSchemaBroadcast,
       useColumnFamilies, storeConf, hadoopConfBroadcast.value.value)
+    // Store reference for this partition
+    partitionStores.set(partitionStores.get() + (partition.index -> store))
     storeReadFunction(store, inputIter)
   }
 }
@@ -126,16 +142,59 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
 
+  // Recursively find a state store provider in the RDD lineage
+  private def findStateStoreProvider(rdd: RDD[_]): Option[StateStoreRDDProvider] = {
+    rdd match {
+      case provider: StateStoreRDDProvider => Some(provider)
+      case _ if rdd.dependencies.isEmpty => None
+      case _ =>
+        // Search all dependencies
+        rdd.dependencies.view
+          .map(dep => findStateStoreProvider(dep.rdd))
+          .find(_.isDefined)
+          .flatten
+    }
+  }
+
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     val storeProviderId = getStateProviderId(partition)
-
     val inputIter = dataRDD.iterator(partition, ctxt)
-    val store = StateStore.get(
-      storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
-      uniqueId.map(_.apply(partition.index).head),
-      stateSchemaBroadcast,
-      useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
-      useMultipleValuesPerKey)
+
+    // Try to find a state store provider in the RDD lineage
+    val store = findStateStoreProvider(dataRDD).flatMap { provider =>
+      provider.getStateStoreForPartition(partition.index)
+    } match {
+      case Some(readStore) =>
+        // Convert the read store to a writable store
+        StateStore.getWriteStore(
+          readStore,
+          storeProviderId,
+          keySchema,
+          valueSchema,
+          keyStateEncoderSpec,
+          storeVersion,
+          uniqueId.map(_.apply(partition.index).head),
+          stateSchemaBroadcast,
+          useColumnFamilies,
+          storeConf,
+          hadoopConfBroadcast.value.value,
+          useMultipleValuesPerKey)
+
+      case None =>
+        // Fall back to creating a new store
+        StateStore.get(
+          storeProviderId,
+          keySchema,
+          valueSchema,
+          keyStateEncoderSpec,
+          storeVersion,
+          uniqueId.map(_.apply(partition.index).head),
+          stateSchemaBroadcast,
+          useColumnFamilies,
+          storeConf,
+          hadoopConfBroadcast.value.value,
+          useMultipleValuesPerKey)
+    }
     storeUpdateFunction(store, inputIter)
   }
 }
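The write-side RDD now searches its dependency graph for the ReadStateStoreRDD that already opened this partition's store (published through a ThreadLocal map, which assumes the read and write stages of a partition run pipelined in the same task thread) and falls back to StateStore.get only when no provider is found. A standalone sketch of that depth-first lineage walk; Node, ProviderNode, and findProvider are hypothetical stand-ins for the RDD types, not Spark classes:

object LineageWalkSketch {
  trait StoreProvider { def storeFor(partition: Int): Option[String] }

  class Node(val deps: Seq[Node])
  class ProviderNode(deps: Seq[Node], stores: Map[Int, String])
      extends Node(deps) with StoreProvider {
    override def storeFor(partition: Int): Option[String] = stores.get(partition)
  }

  // Mirrors findStateStoreProvider: stop at the first provider; otherwise
  // search the dependencies lazily and keep the first hit.
  def findProvider(node: Node): Option[StoreProvider] = node match {
    case p: StoreProvider => Some(p)
    case _ if node.deps.isEmpty => None
    case _ => node.deps.view.map(findProvider).find(_.isDefined).flatten
  }

  def main(args: Array[String]): Unit = {
    val leaf = new ProviderNode(Nil, Map(0 -> "read-store-p0"))
    val root = new Node(Seq(new Node(Seq(leaf))))
    println(findProvider(root).flatMap(_.storeFor(0))) // Some(read-store-p0)
    println(findProvider(root).flatMap(_.storeFor(1))) // None: fall back to a fresh store
  }
}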
