@@ -27,6 +27,10 @@ import org.apache.spark.sql.internal.SessionState
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration

+trait StateStoreRDDProvider {
+  def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore]
+}
+
 abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag](
     dataRDD: RDD[T],
     checkpointLocation: String,
@@ -82,7 +86,17 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     useColumnFamilies: Boolean = false,
     extraOptions: Map[String, String] = Map.empty)
   extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
-    sessionState, storeCoordinator, extraOptions) {
+    sessionState, storeCoordinator, extraOptions) with StateStoreRDDProvider {
+
+  // ThreadLocal to store state stores by partition ID
+  @transient private lazy val partitionStores =
+    new ThreadLocal[Map[Int, ReadStateStore]]() {
+      override def initialValue(): Map[Int, ReadStateStore] = Map.empty
+    }
+
+  override def getStateStoreForPartition(partitionId: Int): Option[ReadStateStore] = {
+    Option(partitionStores.get()).flatMap(_.get(partitionId))
+  }

   override protected def getPartitions: Array[Partition] = dataRDD.partitions

@@ -95,6 +109,8 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
       stateStoreCkptIds.map(_.apply(partition.index).head),
       stateSchemaBroadcast,
       useColumnFamilies, storeConf, hadoopConfBroadcast.value.value)
+    // Store reference for this partition
+    partitionStores.set(partitionStores.get() + (partition.index -> store))
     storeReadFunction(store, inputIter)
   }
 }
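// --- Aside (not part of the patch): a minimal, standalone sketch of the ThreadLocal
// registry pattern used in ReadStateStoreRDD above. The write-side RDD runs in the same
// task thread as the read-side RDD it consumes (via dataRDD.iterator), so a thread-local
// partition-id -> store map lets the downstream RDD look up the store the upstream RDD
// already opened. DummyReadStore and the method names here are hypothetical stand-ins.
object PartitionStoreRegistrySketch {
  final case class DummyReadStore(partitionId: Int, version: Long)

  private val partitionStores =
    new ThreadLocal[Map[Int, DummyReadStore]]() {
      override def initialValue(): Map[Int, DummyReadStore] = Map.empty
    }

  // Called by the read side after it opens a store for a partition.
  def register(partitionId: Int, store: DummyReadStore): Unit =
    partitionStores.set(partitionStores.get() + (partitionId -> store))

  // Called by the write side to reuse the already-open store, if any.
  def lookup(partitionId: Int): Option[DummyReadStore] =
    Option(partitionStores.get()).flatMap(_.get(partitionId))

  def main(args: Array[String]): Unit = {
    register(3, DummyReadStore(3, version = 42L))
    assert(lookup(3).contains(DummyReadStore(3, 42L)))
    assert(lookup(7).isEmpty) // nothing registered for this partition
  }
}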
@@ -126,16 +142,59 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](

   override protected def getPartitions: Array[Partition] = dataRDD.partitions

+  // Recursively find a state store provider in the RDD lineage
+  private def findStateStoreProvider(rdd: RDD[_]): Option[StateStoreRDDProvider] = {
+    rdd match {
+      case provider: StateStoreRDDProvider => Some(provider)
+      case _ if rdd.dependencies.isEmpty => None
+      case _ =>
+        // Search all dependencies
+        rdd.dependencies.view
+          .map(dep => findStateStoreProvider(dep.rdd))
+          .find(_.isDefined)
+          .flatten
+    }
+  }
+
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     val storeProviderId = getStateProviderId(partition)
-
     val inputIter = dataRDD.iterator(partition, ctxt)
-    val store = StateStore.get(
-      storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
-      uniqueId.map(_.apply(partition.index).head),
-      stateSchemaBroadcast,
-      useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
-      useMultipleValuesPerKey)
+
+    // Try to find a state store provider in the RDD lineage
+    val store = findStateStoreProvider(dataRDD).flatMap { provider =>
+      provider.getStateStoreForPartition(partition.index)
+    } match {
+      case Some(readStore) =>
+        // Convert the read store to a writable store
+        StateStore.getWriteStore(
+          readStore,
+          storeProviderId,
+          keySchema,
+          valueSchema,
+          keyStateEncoderSpec,
+          storeVersion,
+          uniqueId.map(_.apply(partition.index).head),
+          stateSchemaBroadcast,
+          useColumnFamilies,
+          storeConf,
+          hadoopConfBroadcast.value.value,
+          useMultipleValuesPerKey)
+
+      case None =>
+        // Fall back to creating a new store
+        StateStore.get(
+          storeProviderId,
+          keySchema,
+          valueSchema,
+          keyStateEncoderSpec,
+          storeVersion,
+          uniqueId.map(_.apply(partition.index).head),
+          stateSchemaBroadcast,
+          useColumnFamilies,
+          storeConf,
+          hadoopConfBroadcast.value.value,
+          useMultipleValuesPerKey)
+    }
     storeUpdateFunction(store, inputIter)
   }
 }
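// --- Aside (not part of the patch): a standalone sketch of the lineage search added to
// StateStoreRDD above. findStateStoreProvider walks the RDD dependency graph depth-first
// and returns the first RDD mixing in StateStoreRDDProvider; the write path then upgrades
// that provider's per-partition read store (StateStore.getWriteStore) and only falls back
// to StateStore.get when no provider is found. The tiny Node graph below is a hypothetical
// stand-in for RDD dependencies, illustrating the same search shape.
object LineageSearchSketch {
  sealed trait Node { def deps: Seq[Node] }
  final case class Plain(deps: Seq[Node] = Nil) extends Node
  final case class Provider(name: String, deps: Seq[Node] = Nil) extends Node

  // Depth-first search mirroring findStateStoreProvider: check the node itself,
  // then recurse into its dependencies and keep the first match.
  def findProvider(node: Node): Option[Provider] = node match {
    case p: Provider => Some(p)
    case _ if node.deps.isEmpty => None
    case _ =>
      node.deps.view
        .map(findProvider)
        .find(_.isDefined)
        .flatten
  }

  def main(args: Array[String]): Unit = {
    val lineage = Plain(Seq(Plain(Seq(Provider("read-store-rdd")))))
    assert(findProvider(lineage).map(_.name).contains("read-store-rdd"))
    assert(findProvider(Plain()).isEmpty) // no provider anywhere in the lineage
  }
}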