Skip to content

Commit a94c3ca

Browse files
committed
Merge branch 'SPARK-21414' into 'spark_2.1'
[SPARK-21414] Refine SlidingWindowFunctionFrame to avoid OOM. Resolves apache#66. See merge request !59.
2 parents b93e102 + ec26f51 commit a94c3ca

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,6 @@ private[window] final class SlidingWindowFunctionFrame(
195195
override def write(index: Int, current: InternalRow): Unit = {
196196
var bufferUpdated = index == 0
197197

198-
// Add all rows to the buffer for which the input row value is equal to or less than
199-
// the output row upper bound.
200-
while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
201-
buffer.add(nextRow.copy())
202-
nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
203-
inputHighIndex += 1
204-
bufferUpdated = true
205-
}
206-
207198
// Drop all rows from the buffer for which the input row value is smaller than
208199
// the output row lower bound.
209200
while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
@@ -212,6 +203,19 @@ private[window] final class SlidingWindowFunctionFrame(
212203
bufferUpdated = true
213204
}
214205

206+
// Add all rows to the buffer for which the input row value is equal to or less than
207+
// the output row upper bound.
208+
while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
209+
if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) {
210+
inputLowIndex += 1
211+
} else {
212+
buffer.add(nextRow.copy())
213+
bufferUpdated = true
214+
}
215+
nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
216+
inputHighIndex += 1
217+
}
218+
215219
// Only recalculate and update when the buffer changes.
216220
if (bufferUpdated) {
217221
processor.initialize(input.length)

sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,46 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
356356
spark.catalog.dropTempView("nums")
357357
}
358358

359+
test("window function: mutiple window expressions specified by range in a single expression") {
360+
val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
361+
nums.createOrReplaceTempView("nums")
362+
withTempView("nums") {
363+
val expected =
364+
Row(1, 1, 1, 4, null, 8, 25) ::
365+
Row(1, 3, 4, 9, 1, 12, 24) ::
366+
Row(1, 5, 9, 15, 4, 16, 21) ::
367+
Row(1, 7, 16, 21, 8, 9, 16) ::
368+
Row(1, 9, 25, 16, 12, null, 9) ::
369+
Row(0, 2, 2, 6, null, 10, 30) ::
370+
Row(0, 4, 6, 12, 2, 14, 28) ::
371+
Row(0, 6, 12, 18, 6, 18, 24) ::
372+
Row(0, 8, 20, 24, 10, 10, 18) ::
373+
Row(0, 10, 30, 18, 14, null, 10) ::
374+
Nil
375+
376+
val actual = sql(
377+
"""
378+
|SELECT
379+
| y,
380+
| x,
381+
| sum(x) over w1 as history_sum,
382+
| sum(x) over w2 as period_sum1,
383+
| sum(x) over w3 as period_sum2,
384+
| sum(x) over w4 as period_sum3,
385+
| sum(x) over w5 as future_sum
386+
|FROM nums
387+
|WINDOW
388+
| w1 AS (PARTITION BY y ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW),
389+
| w2 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING),
390+
| w3 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 4 PRECEDING AND 2 PRECEDING ),
391+
| w4 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 4 FOLLOWING),
392+
| w5 AS (PARTITION BY y ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
393+
""".stripMargin
394+
)
395+
checkAnswer(actual, expected)
396+
}
397+
}
398+
359399
test("SPARK-7595: Window will cause resolve failed with self join") {
360400
checkAnswer(sql(
361401
"""

0 commit comments

Comments
 (0)