diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index a849c3894f0d6..d49e5ed56626e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -86,7 +86,8 @@ abstract class OffsetWindowFunctionFrameBase( expressions: Array[OffsetWindowFunction], inputSchema: Seq[Attribute], newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, - offset: Int) + offset: Int, + ignoreNulls: Boolean) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ @@ -140,6 +141,8 @@ abstract class OffsetWindowFunctionFrameBase( // is not null. protected var skippedNonNullCount = 0 + protected val absOffset = Math.abs(offset) + // Reset the states by the data of the new partition. protected def resetStates(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows @@ -175,6 +178,31 @@ abstract class OffsetWindowFunctionFrameBase( } } + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { + if (absOffset > rows.length) { + fillDefaultValue(EmptyRow) + } else { + resetStates(rows) + if (ignoreNulls) { + prepareForIgnoreNulls() + } else { + prepareForRespectNulls() + } + } + } + + protected def prepareForIgnoreNulls(): Unit = findNextRowWithNonNullInput() + + protected def prepareForRespectNulls(): Unit + + override def write(index: Int, current: InternalRow): Unit = { + if (input != null) { + doWrite(index, current) + } + } + + protected def doWrite(index: Int, current: InternalRow): Unit + override def currentLowerBound(): Int = throw new UnsupportedOperationException() override def currentUpperBound(): Int = throw new UnsupportedOperationException() @@ -196,24 +224,15 @@ class FrameLessOffsetWindowFunctionFrame( offset: Int, ignoreNulls: Boolean = false) extends OffsetWindowFunctionFrameBase( - target, ordinal, expressions, inputSchema, newMutableProjection, offset) { + target, ordinal, expressions, inputSchema, newMutableProjection, offset, ignoreNulls) { - override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { - resetStates(rows) - if (ignoreNulls) { - if (Math.abs(offset) > rows.length) { - fillDefaultValue(EmptyRow) - } else { - findNextRowWithNonNullInput() - } - } else { - // drain the first few rows if offset is larger than zero - while (inputIndex < offset) { - if (inputIterator.hasNext) inputIterator.next() - inputIndex += 1 - } - inputIndex = offset + override def prepareForRespectNulls(): Unit = { + // drain the first few rows if offset is larger than zero + while (inputIndex < offset) { + if (inputIterator.hasNext) inputIterator.next() + inputIndex += 1 } + inputIndex = offset } private val doWrite = if (ignoreNulls && offset > 0) { @@ -260,7 +279,6 @@ class FrameLessOffsetWindowFunctionFrame( // 7. current row -> z, next selected row -> y, output: y; // 8. current row -> v, next selected row -> z, output: z; // 9. current row -> null, next selected row -> v, output: v; - val absOffset = Math.abs(offset) (current: InternalRow) => if (skippedNonNullCount == absOffset) { nextSelectedRow = EmptyRow @@ -294,7 +312,7 @@ class FrameLessOffsetWindowFunctionFrame( inputIndex += 1 } - override def write(index: Int, current: InternalRow): Unit = { + protected def doWrite(index: Int, current: InternalRow): Unit = { doWrite(current) } } @@ -317,35 +335,30 @@ class UnboundedOffsetWindowFunctionFrame( offset: Int, ignoreNulls: Boolean = false) extends OffsetWindowFunctionFrameBase( - target, ordinal, expressions, inputSchema, newMutableProjection, offset) { + target, ordinal, expressions, inputSchema, newMutableProjection, offset, ignoreNulls) { assert(offset > 0) - override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { - if (offset > rows.length) { + override def prepareForIgnoreNulls(): Unit = { + findNextRowWithNonNullInput() + if (nextSelectedRow == EmptyRow) { + // Use default values since the offset row whose input value is not null does not exist. fillDefaultValue(EmptyRow) } else { - resetStates(rows) - if (ignoreNulls) { - findNextRowWithNonNullInput() - if (nextSelectedRow == EmptyRow) { - // Use default values since the offset row whose input value is not null does not exist. - fillDefaultValue(EmptyRow) - } else { - projection(nextSelectedRow) - } - } else { - var selectedRow: UnsafeRow = null - // drain the first few rows if offset is larger than one - while (inputIndex < offset) { - selectedRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputIndex += 1 - } - projection(selectedRow) - } + projection(nextSelectedRow) } } - override def write(index: Int, current: InternalRow): Unit = { + override def prepareForRespectNulls(): Unit = { + var selectedRow: UnsafeRow = null + // drain the first few rows if offset is larger than one + while (inputIndex < offset) { + selectedRow = WindowFunctionFrame.getNextOrNull(inputIterator) + inputIndex += 1 + } + projection(selectedRow) + } + + protected def doWrite(index: Int, current: InternalRow): Unit = { // The results are the same for each row in the partition, and have been evaluated in prepare. // Don't need to recalculate here. } @@ -370,27 +383,18 @@ class UnboundedPrecedingOffsetWindowFunctionFrame( offset: Int, ignoreNulls: Boolean = false) extends OffsetWindowFunctionFrameBase( - target, ordinal, expressions, inputSchema, newMutableProjection, offset) { + target, ordinal, expressions, inputSchema, newMutableProjection, offset, ignoreNulls) { assert(offset > 0) - override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { - if (offset > rows.length) { - fillDefaultValue(EmptyRow) - } else { - resetStates(rows) - if (ignoreNulls) { - findNextRowWithNonNullInput() - } else { - // drain the first few rows if offset is larger than one - while (inputIndex < offset) { - nextSelectedRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputIndex += 1 - } - } + override def prepareForRespectNulls(): Unit = { + // drain the first few rows if offset is larger than one + while (inputIndex < offset) { + nextSelectedRow = WindowFunctionFrame.getNextOrNull(inputIterator) + inputIndex += 1 } } - override def write(index: Int, current: InternalRow): Unit = { + protected def doWrite(index: Int, current: InternalRow): Unit = { if (index >= inputIndex - 1 && nextSelectedRow != null) { projection(nextSelectedRow) } else {