Skip to content

Commit cf24fad

Browse files
committed
Abort StateStore on error
1 parent e635cbb commit cf24fad

File tree

5 files changed

+86
-6
lines changed

5 files changed

+86
-6
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.state._
3131
import org.apache.spark.sql.execution.SparkPlan
3232
import org.apache.spark.sql.streaming.OutputMode
3333
import org.apache.spark.sql.types.StructType
34+
import org.apache.spark.TaskContext
3435

3536

3637
/** Used to identify the state store for a given operator. */
@@ -150,6 +151,13 @@ case class StateStoreSaveExec(
150151
val numTotalStateRows = longMetric("numTotalStateRows")
151152
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
152153

154+
// Abort the state store in case of error
155+
TaskContext.get().addTaskCompletionListener(_ => {
156+
if (!store.hasCommitted) {
157+
store.abort()
158+
}
159+
})
160+
153161
outputMode match {
154162
// Update and output all rows in the StateStore.
155163
case Some(Complete) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ private[state] class HDFSBackedStateStoreProvider(
203203
/**
204204
* Whether all updates have been committed
205205
*/
206-
override private[state] def hasCommitted: Boolean = {
206+
override private[streaming] def hasCommitted: Boolean = {
207207
state == COMMITTED
208208
}
209209

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ trait StateStore {
8383
/**
8484
* Whether all updates have been committed
8585
*/
86-
private[state] def hasCommitted: Boolean
86+
private[streaming] def hasCommitted: Boolean
8787
}
8888

8989

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
2020
import org.apache.spark.sql.internal.SQLConf
2121

2222
/** A class that contains configuration parameters for [[StateStore]]s. */
23-
private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
23+
private[sql] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
2424

2525
def this() = this(new SQLConf)
2626

@@ -29,7 +29,7 @@ private[streaming] class StateStoreConf(@transient private val conf: SQLConf) ex
2929
val minVersionsToRetain = conf.minBatchesToRetain
3030
}
3131

32-
private[streaming] object StateStoreConf {
32+
private[sql] object StateStoreConf {
3333
val empty = new StateStoreConf()
3434

3535
def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf)

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,26 @@ package org.apache.spark.sql.streaming
1919

2020
import java.util.TimeZone
2121

22+
import scala.collection.mutable
23+
import scala.reflect.runtime.{universe => ru}
24+
25+
import org.apache.hadoop.conf.Configuration
26+
import org.mockito.Mockito
27+
import org.mockito.invocation.InvocationOnMock
28+
import org.mockito.stubbing.Answer
2229
import org.scalatest.BeforeAndAfterAll
30+
import org.scalatest.PrivateMethodTester._
2331

2432
import org.apache.spark.SparkException
2533
import org.apache.spark.sql.AnalysisException
26-
import org.apache.spark.sql.catalyst.util.DateTimeUtils
34+
import org.apache.spark.sql.catalyst.util._
2735
import org.apache.spark.sql.execution.SparkPlan
2836
import org.apache.spark.sql.execution.streaming._
29-
import org.apache.spark.sql.execution.streaming.state.StateStore
37+
import org.apache.spark.sql.execution.streaming.state._
3038
import org.apache.spark.sql.expressions.scalalang.typed
3139
import org.apache.spark.sql.functions._
3240
import org.apache.spark.sql.streaming.OutputMode._
41+
import org.apache.spark.sql.types._
3342

3443
object FailureSinglton {
3544
var firstTime = true
@@ -335,4 +344,67 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
335344
CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
336345
)
337346
}
347+
348+
test("abort StateStore in case of error") {
349+
quietly {
350+
val inputData = MemoryStream[Long]
351+
val aggregated =
352+
inputData.toDS()
353+
.groupBy($"value")
354+
.agg(count("*"))
355+
var aborted = false
356+
testStream(aggregated, Complete)(
357+
// This whole `AssertOnQuery` is used to inject a mock state store
358+
AssertOnQuery(execution => {
359+
// (1) Use reflection to get `StateStore.loadedProviders`
360+
val loadedProviders = {
361+
val field = ru.typeOf[StateStore.type].decl(ru.TermName("loadedProviders")).asTerm
362+
ru.runtimeMirror(StateStore.getClass.getClassLoader)
363+
.reflect(StateStore)
364+
.reflectField(field)
365+
.get
366+
.asInstanceOf[mutable.HashMap[StateStoreId, StateStoreProvider]]
367+
}
368+
// (2) Make a storeId
369+
val storeId = {
370+
val checkpointLocation =
371+
execution invokePrivate PrivateMethod[String]('checkpointFile)("state")
372+
StateStoreId(checkpointLocation, 0L, 0)
373+
}
374+
// (3) Make `mockStore` and `mockProvider`
375+
val (mockStore, mockProvider) = {
376+
val keySchema = StructType(Seq(
377+
StructField("value", LongType, false)))
378+
val valueSchema = StructType(Seq(
379+
StructField("value", LongType, false), StructField("count", LongType, false)))
380+
val storeConf = StateStoreConf.empty
381+
val hadoopConf = new Configuration
382+
(Mockito.spy(
383+
StateStore.get(storeId, keySchema, valueSchema, version = 0, storeConf, hadoopConf)),
384+
Mockito.spy(loadedProviders.get(storeId).get))
385+
}
386+
// (4) Setup `mockStore` and `mockProvider`
387+
Mockito.doAnswer(new Answer[Long] {
388+
override def answer(invocationOnMock: InvocationOnMock): Long = {
389+
sys.error("injected error on commit()")
390+
}
391+
}).when(mockStore).commit()
392+
Mockito.doAnswer(new Answer[Unit] {
393+
override def answer(invocationOnMock: InvocationOnMock): Unit = {
394+
invocationOnMock.callRealMethod()
395+
// Mark the flag for later check
396+
aborted = true
397+
}
398+
}).when(mockStore).abort()
399+
Mockito.doReturn(mockStore).when(mockProvider).getStore(version = 0)
400+
// (5) Inject `mockProvider`, which later on would inject `mockStore`
401+
loadedProviders.put(storeId, mockProvider)
402+
true
403+
}), // End of AssertOnQuery, i.e. end of injecting `mockStore`
404+
AddData(inputData, 1L, 2L, 3L),
405+
ExpectFailure[SparkException](),
406+
AssertOnQuery { _ => aborted } // Check that `mockStore.abort()` is called upon error
407+
)
408+
}
409+
}
338410
}

0 commit comments

Comments
 (0)