
Commit 2f8613f

[SPARK-38684][SS] Fix correctness issue on stream-stream outer join with RocksDB state store provider
### What changes were proposed in this pull request?

(Credit to alex-balikov for the inspiration of the root cause observation, and anishshri-db for looking into the issue together.)

This PR fixes a correctness issue in stream-stream outer join with the RocksDB state store provider. The issue can occur under certain conditions, like below:

* stream-stream time-interval outer join
* left outer join has the issue on the left side, right outer join on the right side, full outer join on both sides
* at batch N, non-late row(s) are produced on the problematic side
* in the same batch (batch N), some row(s) on the problematic side are evicted by the watermark condition

The root cause is the same as [SPARK-38320](https://issues.apache.org/jira/browse/SPARK-38320): weak read consistency on the iterator, especially with the RocksDB state store provider. (Quoting from SPARK-38320: "The problem is due to the StateStore.iterator not reflecting StateStore changes made after its creation.")

More specifically, if processing input rows updates the number of values for a grouping key, the update is not seen by SymmetricHashJoinStateManager.removeByValueCondition, and that method performs the eviction with a stale (out-of-sync) number of values. Worse, if the method performs the eviction and then updates the number of values for the grouping key, it "overwrites" the count, effectively dropping all rows inserted in the same batch.

The code blocks below are references for understanding the details of the issue:

https://github.com/apache/spark/blob/ca7200b0008dc6101a252020e6c34ef7b72d81d6/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala#L327-L339
https://github.com/apache/spark/blob/ca7200b0008dc6101a252020e6c34ef7b72d81d6/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala#L619-L627
https://github.com/apache/spark/blob/ca7200b0008dc6101a252020e6c34ef7b72d81d6/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala#L195-L201
https://github.com/apache/spark/blob/ca7200b0008dc6101a252020e6c34ef7b72d81d6/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala#L208-L223

This PR changes the outer iterators to be lazily evaluated, ensuring that all updates from processing input rows are reflected "before" the outer iterators are initialized.

### Why are the changes needed?

The bug is described in the section above.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit test added.

Closes #36002 from HeartSaVioR/SPARK-38684.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
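
To make the failure mode concrete, here is a minimal, self-contained Scala sketch of the hazard (toy code, not Spark's API; `ToyStore` and `StaleReadDemo` are hypothetical stand-ins for a state store whose iterator has snapshot-at-creation semantics):

```scala
import scala.collection.mutable

class ToyStore {
  private val data = mutable.Map.empty[String, Int]
  def put(key: String, numValues: Int): Unit = data.put(key, numValues)
  // Snapshot-at-creation: updates made after this call are not visible.
  def iterator: Iterator[(String, Int)] = data.toList.iterator
}

object StaleReadDemo extends App {
  val store = new ToyStore
  store.put("abc", 1)            // state from previous batches

  val eagerIter = store.iterator // created before processing input rows
  store.put("abc", 2)            // processing input rows updates the count

  // The eviction-side reader sees the stale count (1, not 2) -- the shape of
  // the bug: it may then "overwrite" state based on stale data.
  println(eagerIter.toList)      // List((abc,1))
  println(store.iterator.toList) // List((abc,2)) -- a fresh iterator sees the update
}
```

A fresh iterator created after the update sees the new count; the fix below applies the same idea by deferring iterator creation until the outer-join side is actually consumed.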
1 parent: eb353aa

File tree

2 files changed: 121 additions & 23 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala

Lines changed: 59 additions & 22 deletions
@@ -324,17 +324,22 @@ case class StreamingSymmetricHashJoinExec(
           }
         }
 
+        val initIterFn = { () =>
+          val removedRowIter = leftSideJoiner.removeOldState()
+          removedRowIter.filterNot { kv =>
+            stateFormatVersion match {
+              case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, kv.value))
+              case 2 => kv.matched
+              case _ => throwBadStateFormatVersionException()
+            }
+          }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+        }
+
         // NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of
-        // elements in `innerOutputIter`, because evaluation of `innerOutputIter` may update
-        // the match flag which the logic for outer join is relying on.
-        val removedRowIter = leftSideJoiner.removeOldState()
-        val outerOutputIter = removedRowIter.filterNot { kv =>
-          stateFormatVersion match {
-            case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, kv.value))
-            case 2 => kv.matched
-            case _ => throwBadStateFormatVersionException()
-          }
-        }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+        // elements in `hashJoinOutputIter`, otherwise it may lead to out of sync according to
+        // the interface contract on StateStore.iterator and end up with correctness issue.
+        // Please refer SPARK-38684 for more details.
+        val outerOutputIter = new LazilyInitializingJoinedRowIterator(initIterFn)
 
         hashJoinOutputIter ++ outerOutputIter
       case RightOuter =>
@@ -344,14 +349,23 @@ case class StreamingSymmetricHashJoinExec(
             postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
           }
         }
-        val removedRowIter = rightSideJoiner.removeOldState()
-        val outerOutputIter = removedRowIter.filterNot { kv =>
-          stateFormatVersion match {
-            case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value))
-            case 2 => kv.matched
-            case _ => throwBadStateFormatVersionException()
-          }
-        }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+
+        val initIterFn = { () =>
+          val removedRowIter = rightSideJoiner.removeOldState()
+          removedRowIter.filterNot { kv =>
+            stateFormatVersion match {
+              case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value))
+              case 2 => kv.matched
+              case _ => throwBadStateFormatVersionException()
+            }
+          }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+        }
+
+        // NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of
+        // elements in `hashJoinOutputIter`, otherwise it may lead to out of sync according to
+        // the interface contract on StateStore.iterator and end up with correctness issue.
+        // Please refer SPARK-38684 for more details.
+        val outerOutputIter = new LazilyInitializingJoinedRowIterator(initIterFn)
 
         hashJoinOutputIter ++ outerOutputIter
       case FullOuter =>
@@ -360,10 +374,25 @@ case class StreamingSymmetricHashJoinExec(
             case 2 => kv.matched
             case _ => throwBadStateFormatVersionException()
           }
-        val leftSideOutputIter = leftSideJoiner.removeOldState().filterNot(
-          isKeyToValuePairMatched).map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
-        val rightSideOutputIter = rightSideJoiner.removeOldState().filterNot(
-          isKeyToValuePairMatched).map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+
+        val leftSideInitIterFn = { () =>
+          val removedRowIter = leftSideJoiner.removeOldState()
+          removedRowIter.filterNot(isKeyToValuePairMatched)
+            .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+        }
+
+        val rightSideInitIterFn = { () =>
+          val removedRowIter = rightSideJoiner.removeOldState()
+          removedRowIter.filterNot(isKeyToValuePairMatched)
+            .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+        }
+
+        // NOTE: we need to make sure both `leftSideOutputIter` and `rightSideOutputIter` are
+        // evaluated "after" exhausting all of elements in `hashJoinOutputIter`, otherwise it may
+        // lead to out of sync according to the interface contract on StateStore.iterator and
+        // end up with correctness issue. Please refer SPARK-38684 for more details.
+        val leftSideOutputIter = new LazilyInitializingJoinedRowIterator(leftSideInitIterFn)
+        val rightSideOutputIter = new LazilyInitializingJoinedRowIterator(rightSideInitIterFn)
 
         hashJoinOutputIter ++ leftSideOutputIter ++ rightSideOutputIter
       case _ => throwBadJoinTypeException()
@@ -638,4 +667,12 @@ case class StreamingSymmetricHashJoinExec(
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan, newRight: SparkPlan): StreamingSymmetricHashJoinExec =
     copy(left = newLeft, right = newRight)
+
+  private class LazilyInitializingJoinedRowIterator(
+      initFn: () => Iterator[JoinedRow]) extends Iterator[JoinedRow] {
+    private lazy val iter: Iterator[JoinedRow] = initFn()
+
+    override def hasNext: Boolean = iter.hasNext
+    override def next(): JoinedRow = iter.next()
+  }
 }
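
The fix leans on two behaviors: Scala's `Iterator#++` takes its right operand by name and only consults it once the left side is exhausted, and `LazilyInitializingJoinedRowIterator` defers `initFn()` until its first `hasNext`/`next` call. A minimal sketch of the resulting ordering guarantee (hypothetical names, not Spark code):

```scala
object LazyConcatDemo extends App {
  // Same shape as LazilyInitializingJoinedRowIterator above, generic in T.
  class LazyIter[T](initFn: () => Iterator[T]) extends Iterator[T] {
    private lazy val iter: Iterator[T] = initFn() // deferred until first use
    override def hasNext: Boolean = iter.hasNext
    override def next(): T = iter.next()
  }

  var initialized = false
  val hashJoinSide = Iterator(1, 2, 3)
  val outerSide = new LazyIter(() => { initialized = true; Iterator(4, 5) })

  val combined = hashJoinSide ++ outerSide // ++ is lazy in its right operand
  println(combined.next())  // 1
  println(initialized)      // false -- outer side untouched so far
  println(combined.toList)  // List(2, 3, 4, 5)
  println(initialized)      // true -- created only after the left side drained
}
```

In the actual operator, this means `removeOldState()` (and the state store iterator behind it) is created only after `hashJoinOutputIter` has been fully consumed and all per-key updates have been written.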

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala

Lines changed: 62 additions & 1 deletion
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
@@ -1353,6 +1353,67 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite {
     ).select(Symbol("leftKey1"), Symbol("rightKey1"), Symbol("leftKey2"), Symbol("rightKey2"),
       $"leftWindow.end".cast("long"), Symbol("leftValue"), Symbol("rightValue"))
   }
+
+  test("SPARK-38684: outer join works correctly even if processing input rows and " +
+    "evicting state rows for same grouping key happens in the same micro-batch") {
+
+    // The test is to demonstrate the correctness issue in outer join before SPARK-38684.
+    withSQLConf(
+      SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) {
+
+      val input1 = MemoryStream[(Timestamp, String, String)]
+      val df1 = input1.toDF
+        .selectExpr("_1 as eventTime", "_2 as id", "_3 as comment")
+        .withWatermark("eventTime", "0 second")
+
+      val input2 = MemoryStream[(Timestamp, String, String)]
+      val df2 = input2.toDF
+        .selectExpr("_1 as eventTime", "_2 as id", "_3 as comment")
+        .withWatermark("eventTime", "0 second")
+
+      val joined = df1.as("left")
+        .join(df2.as("right"),
+          expr("""
+            |left.id = right.id AND left.eventTime BETWEEN
+            |  right.eventTime - INTERVAL 30 seconds AND
+            |  right.eventTime + INTERVAL 30 seconds
+            """.stripMargin),
+          joinType = "leftOuter")
+
+      testStream(joined)(
+        MultiAddData(
+          (input1, Seq((Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "left in batch 1"))),
+          (input2, Seq((Timestamp.valueOf("2020-01-02 00:01:00"), "abc", "right in batch 1")))
+        ),
+        CheckNewAnswer(),
+        MultiAddData(
+          (input1, Seq((Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "left in batch 2"))),
+          (input2, Seq((Timestamp.valueOf("2020-01-02 01:01:00"), "abc", "right in batch 2")))
+        ),
+        // watermark advanced to "2020-01-02 00:00:00"
+        CheckNewAnswer(),
+        AddData(input1, (Timestamp.valueOf("2020-01-02 01:30:00"), "abc", "left in batch 3")),
+        // watermark advanced to "2020-01-02 01:00:00"
+        CheckNewAnswer(
+          (Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "left in batch 1", null, null, null)
+        ),
+        // left side state should still contain "left in batch 2" and "left in batch 3"
+        // we should see both rows in the left side since
+        // - "left in batch 2" is going to be evicted in this batch
+        // - "left in batch 3" is going to be matched with new row in right side
+        AddData(input2,
+          (Timestamp.valueOf("2020-01-02 01:30:10"), "abc", "match with left in batch 3")),
+        // watermark advanced to "2020-01-02 01:01:00"
+        CheckNewAnswer(
+          (Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "left in batch 2",
+            null, null, null),
+          (Timestamp.valueOf("2020-01-02 01:30:00"), "abc", "left in batch 3",
+            Timestamp.valueOf("2020-01-02 01:30:10"), "abc", "match with left in batch 3")
+        )
+      )
+    }
+  }
 }
 
 class StreamingFullOuterJoinSuite extends StreamingJoinSuite {
