@@ -324,17 +324,22 @@ case class StreamingSymmetricHashJoinExec(
324
324
}
325
325
}
326
326
327
+ val initIterFn = { () =>
328
+ val removedRowIter = leftSideJoiner.removeOldState()
329
+ removedRowIter.filterNot { kv =>
330
+ stateFormatVersion match {
331
+ case 1 => matchesWithRightSideState(new UnsafeRowPair (kv.key, kv.value))
332
+ case 2 => kv.matched
333
+ case _ => throwBadStateFormatVersionException()
334
+ }
335
+ }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
336
+ }
337
+
327
338
// NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of
328
- // elements in `innerOutputIter`, because evaluation of `innerOutputIter` may update
329
- // the match flag which the logic for outer join is relying on.
330
- val removedRowIter = leftSideJoiner.removeOldState()
331
- val outerOutputIter = removedRowIter.filterNot { kv =>
332
- stateFormatVersion match {
333
- case 1 => matchesWithRightSideState(new UnsafeRowPair (kv.key, kv.value))
334
- case 2 => kv.matched
335
- case _ => throwBadStateFormatVersionException()
336
- }
337
- }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
339
+ // elements in `hashJoinOutputIter`, otherwise it may lead to out of sync according to
340
+ // the interface contract on StateStore.iterator and end up with correctness issue.
341
+ // Please refer SPARK-38684 for more details.
342
+ val outerOutputIter = new LazilyInitializingJoinedRowIterator (initIterFn)
338
343
339
344
hashJoinOutputIter ++ outerOutputIter
340
345
case RightOuter =>
@@ -344,14 +349,23 @@ case class StreamingSymmetricHashJoinExec(
344
349
postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
345
350
}
346
351
}
347
- val removedRowIter = rightSideJoiner.removeOldState()
348
- val outerOutputIter = removedRowIter.filterNot { kv =>
349
- stateFormatVersion match {
350
- case 1 => matchesWithLeftSideState(new UnsafeRowPair (kv.key, kv.value))
351
- case 2 => kv.matched
352
- case _ => throwBadStateFormatVersionException()
353
- }
354
- }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
352
+
353
+ val initIterFn = { () =>
354
+ val removedRowIter = rightSideJoiner.removeOldState()
355
+ removedRowIter.filterNot { kv =>
356
+ stateFormatVersion match {
357
+ case 1 => matchesWithLeftSideState(new UnsafeRowPair (kv.key, kv.value))
358
+ case 2 => kv.matched
359
+ case _ => throwBadStateFormatVersionException()
360
+ }
361
+ }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
362
+ }
363
+
364
+ // NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of
365
+ // elements in `hashJoinOutputIter`, otherwise it may lead to out of sync according to
366
+ // the interface contract on StateStore.iterator and end up with correctness issue.
367
+ // Please refer SPARK-38684 for more details.
368
+ val outerOutputIter = new LazilyInitializingJoinedRowIterator (initIterFn)
355
369
356
370
hashJoinOutputIter ++ outerOutputIter
357
371
case FullOuter =>
@@ -360,10 +374,25 @@ case class StreamingSymmetricHashJoinExec(
360
374
case 2 => kv.matched
361
375
case _ => throwBadStateFormatVersionException()
362
376
}
363
- val leftSideOutputIter = leftSideJoiner.removeOldState().filterNot(
364
- isKeyToValuePairMatched).map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
365
- val rightSideOutputIter = rightSideJoiner.removeOldState().filterNot(
366
- isKeyToValuePairMatched).map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
377
+
378
+ val leftSideInitIterFn = { () =>
379
+ val removedRowIter = leftSideJoiner.removeOldState()
380
+ removedRowIter.filterNot(isKeyToValuePairMatched)
381
+ .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
382
+ }
383
+
384
+ val rightSideInitIterFn = { () =>
385
+ val removedRowIter = rightSideJoiner.removeOldState()
386
+ removedRowIter.filterNot(isKeyToValuePairMatched)
387
+ .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
388
+ }
389
+
390
+ // NOTE: we need to make sure both `leftSideOutputIter` and `rightSideOutputIter` are
391
+ // evaluated "after" exhausting all of elements in `hashJoinOutputIter`, otherwise it may
392
+ // lead to out of sync according to the interface contract on StateStore.iterator and
393
+ // end up with correctness issue. Please refer SPARK-38684 for more details.
394
+ val leftSideOutputIter = new LazilyInitializingJoinedRowIterator (leftSideInitIterFn)
395
+ val rightSideOutputIter = new LazilyInitializingJoinedRowIterator (rightSideInitIterFn)
367
396
368
397
hashJoinOutputIter ++ leftSideOutputIter ++ rightSideOutputIter
369
398
case _ => throwBadJoinTypeException()
@@ -638,4 +667,12 @@ case class StreamingSymmetricHashJoinExec(
638
667
override protected def withNewChildrenInternal (
639
668
newLeft : SparkPlan , newRight : SparkPlan ): StreamingSymmetricHashJoinExec =
640
669
copy(left = newLeft, right = newRight)
670
+
671
+ private class LazilyInitializingJoinedRowIterator (
672
+ initFn : () => Iterator [JoinedRow ]) extends Iterator [JoinedRow ] {
673
+ private lazy val iter : Iterator [JoinedRow ] = initFn()
674
+
675
+ override def hasNext : Boolean = iter.hasNext
676
+ override def next (): JoinedRow = iter.next()
677
+ }
641
678
}
0 commit comments