Skip to content

Commit a713a7e

Browse files
Dooyoung HwangHyukjinKwon
authored andcommitted
[SPARK-33655][SQL] Improve performance of processing FETCH_PRIOR
### What changes were proposed in this pull request? Currently, when a client requests FETCH_PRIOR from Thriftserver, Thriftserver re-iterates the result from the start position. Because Thriftserver caches a query result in an array when the THRIFTSERVER_INCREMENTAL_COLLECT feature is off, FETCH_PRIOR can be implemented without re-iterating the result. A trait FetchIterator is added in order to separate the implementations for an iterator and an array. Also, FetchIterator supports moving the cursor to an absolute position, which will be useful for the implementation of FETCH_RELATIVE and FETCH_ABSOLUTE. ### Why are the changes needed? For better performance of Thriftserver. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? FetchIteratorSuite Closes #30600 from Dooyoung-Hwang/refactor_with_fetch_iterator. Authored-by: Dooyoung Hwang <dooyoung.hwang@sk.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
1 parent 48f93af commit a713a7e

File tree

3 files changed

+256
-54
lines changed

3 files changed

+256
-54
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.thriftserver
19+
20+
/**
 * An [[Iterator]] that can additionally be repositioned in units of "fetch blocks".
 *
 * Each of the `fetch*` methods begins a new fetch block and records where that block
 * starts; the rows of the block are then read through the ordinary `hasNext` / `next()`
 * calls.
 */
private[hive] sealed trait FetchIterator[A] extends Iterator[A] {

  /**
   * Starts a new fetch block that continues forward from the current position.
   * The fetch start offset is reset.
   */
  def fetchNext(): Unit

  /**
   * Starts a new fetch block located `offset` rows before the start of the previous
   * fetch block. The fetch start offset is reset.
   *
   * @param offset how far to move the fetch start position backward.
   */
  def fetchPrior(offset: Long): Unit = {
    // The target may be negative here; it is passed to fetchAbsolute unchanged.
    val target = getFetchStart - offset
    fetchAbsolute(target)
  }

  /**
   * Starts a new fetch block at the given position. The fetch start offset is reset.
   *
   * @param pos index the iterator should be moved to.
   */
  def fetchAbsolute(pos: Long): Unit

  /** Returns the fetch start offset, i.e. where the current fetch block begins. */
  def getFetchStart: Long

  /** Returns the current position of the iterator. */
  def getPosition: Long
}
48+
49+
/**
 * A [[FetchIterator]] backed by an in-memory array, so the cursor can be moved to any
 * position by simply reassigning an index.
 *
 * @param src the rows to iterate over.
 */
private[hive] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A] {
  // Index at which the current fetch block began.
  private var fetchStart: Long = 0

  // Index of the element the next call to next() will return.
  private var position: Long = 0

  /** Begins a new fetch block at the current position. */
  override def fetchNext(): Unit = fetchStart = position

  /**
   * Moves the cursor to `pos`, clamped into `[0, src.length]`, and begins a new fetch
   * block there.
   */
  override def fetchAbsolute(pos: Long): Unit = {
    position = (pos max 0) min src.length
    fetchStart = position
  }

  override def getFetchStart: Long = fetchStart

  override def getPosition: Long = position

  override def hasNext: Boolean = position < src.length

  /**
   * Returns the element at the current position and advances the cursor by one.
   *
   * Calling this when `hasNext` is false fails with the array's own index-out-of-bounds
   * exception; the cursor is left untouched in that case.
   */
  override def next(): A = {
    // Read first, advance second: if the index is out of range the array access throws
    // before `position` changes, so getPosition/getFetchStart stay within bounds.
    val row = src(position.toInt)
    position += 1
    row
  }
}
72+
73+
/**
 * A [[FetchIterator]] backed by an [[Iterable]]. Moving forward consumes rows from the
 * current underlying iterator; moving backward obtains a fresh iterator from `iterable`
 * and re-iterates from the beginning up to the requested position.
 *
 * @param iterable source from which a new underlying iterator is requested on every rewind.
 */
private[hive] class IterableFetchIterator[A](iterable: Iterable[A]) extends FetchIterator[A] {
  // Underlying iterator currently being consumed; replaced on rewind.
  private var iter: Iterator[A] = iterable.iterator

  // Position at which the current fetch block began.
  private var fetchStart: Long = 0

  // Number of rows successfully consumed from `iter` since it was (re)created.
  private var position: Long = 0

  /** Begins a new fetch block at the current position. */
  override def fetchNext(): Unit = fetchStart = position

  /**
   * Moves the cursor to `pos` (negative values are treated as 0) and begins a new fetch
   * block there. If the source has fewer than `pos` rows the cursor stops at the end.
   */
  override def fetchAbsolute(pos: Long): Unit = {
    val newPos = pos max 0
    // Going backward is only possible by starting over from a fresh iterator.
    if (newPos < position) resetPosition()
    // Skip forward, discarding rows, until the target or the end of the data is reached.
    while (position < newPos && hasNext) next()
    fetchStart = position
  }

  override def getFetchStart: Long = fetchStart

  override def getPosition: Long = position

  override def hasNext: Boolean = iter.hasNext

  /**
   * Returns the next row of the underlying iterator and advances the cursor by one.
   * If the underlying iterator throws, the exception propagates and the cursor is not
   * advanced.
   */
  override def next(): A = {
    // Fetch first, count second: `position` must only reflect rows that were actually
    // delivered, otherwise a failure in iter.next() would skew every later offset.
    val row = iter.next()
    position += 1
    row
  }

  /** Restarts iteration from the beginning, unless the cursor is already there. */
  private def resetPosition(): Unit = {
    if (position != 0) {
      iter = iterable.iterator
      position = 0
      fetchStart = 0
    }
  }
}

sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala

Lines changed: 15 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,7 @@ private[hive] class SparkExecuteStatementOperation(
6969

7070
private var result: DataFrame = _
7171

72-
// We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST.
73-
// This is only used when `spark.sql.thriftServer.incrementalCollect` is set to `false`.
74-
// In case of `true`, this will be `None` and FETCH_FIRST will trigger re-execution.
75-
private var resultList: Option[Array[SparkRow]] = _
76-
private var previousFetchEndOffset: Long = 0
77-
private var previousFetchStartOffset: Long = 0
78-
private var iter: Iterator[SparkRow] = _
72+
private var iter: FetchIterator[SparkRow] = _
7973
private var dataTypes: Array[DataType] = _
8074

8175
private lazy val resultSchema: TableSchema = {
@@ -148,43 +142,14 @@ private[hive] class SparkExecuteStatementOperation(
148142
setHasResultSet(true)
149143
val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
150144

151-
// Reset iter when FETCH_FIRST or FETCH_PRIOR
152-
if ((order.equals(FetchOrientation.FETCH_FIRST) ||
153-
order.equals(FetchOrientation.FETCH_PRIOR)) && previousFetchEndOffset != 0) {
154-
// Reset the iterator to the beginning of the query.
155-
iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
156-
resultList = None
157-
result.toLocalIterator.asScala
158-
} else {
159-
if (resultList.isEmpty) {
160-
resultList = Some(result.collect())
161-
}
162-
resultList.get.iterator
163-
}
164-
}
165-
166-
var resultOffset = {
167-
if (order.equals(FetchOrientation.FETCH_FIRST)) {
168-
logInfo(s"FETCH_FIRST request with $statementId. Resetting to resultOffset=0")
169-
0
170-
} else if (order.equals(FetchOrientation.FETCH_PRIOR)) {
171-
// TODO: FETCH_PRIOR should be handled more efficiently than rewinding to beginning and
172-
// reiterating.
173-
val targetOffset = math.max(previousFetchStartOffset - maxRowsL, 0)
174-
logInfo(s"FETCH_PRIOR request with $statementId. Resetting to resultOffset=$targetOffset")
175-
var off = 0
176-
while (off < targetOffset && iter.hasNext) {
177-
iter.next()
178-
off += 1
179-
}
180-
off
181-
} else { // FETCH_NEXT
182-
previousFetchEndOffset
183-
}
145+
if (order.equals(FetchOrientation.FETCH_FIRST)) {
146+
iter.fetchAbsolute(0)
147+
} else if (order.equals(FetchOrientation.FETCH_PRIOR)) {
148+
iter.fetchPrior(maxRowsL)
149+
} else {
150+
iter.fetchNext()
184151
}
185-
186-
resultRowSet.setStartOffset(resultOffset)
187-
previousFetchStartOffset = resultOffset
152+
resultRowSet.setStartOffset(iter.getPosition)
188153
if (!iter.hasNext) {
189154
resultRowSet
190155
} else {
@@ -206,11 +171,9 @@ private[hive] class SparkExecuteStatementOperation(
206171
}
207172
resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]])
208173
curRow += 1
209-
resultOffset += 1
210174
}
211-
previousFetchEndOffset = resultOffset
212175
log.info(s"Returning result set with ${curRow} rows from offsets " +
213-
s"[$previousFetchStartOffset, $previousFetchEndOffset) with $statementId")
176+
s"[${iter.getFetchStart}, ${iter.getPosition}) with $statementId")
214177
resultRowSet
215178
}
216179
}
@@ -326,14 +289,12 @@ private[hive] class SparkExecuteStatementOperation(
326289
logDebug(result.queryExecution.toString())
327290
HiveThriftServer2.eventManager.onStatementParsed(statementId,
328291
result.queryExecution.toString())
329-
iter = {
330-
if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
331-
resultList = None
332-
result.toLocalIterator.asScala
333-
} else {
334-
resultList = Some(result.collect())
335-
resultList.get.iterator
336-
}
292+
iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
293+
new IterableFetchIterator[SparkRow](new Iterable[SparkRow] {
294+
override def iterator: Iterator[SparkRow] = result.toLocalIterator.asScala
295+
})
296+
} else {
297+
new ArrayFetchIterator[SparkRow](result.collect())
337298
}
338299
dataTypes = result.schema.fields.map(_.dataType)
339300
} catch {
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.thriftserver
19+
20+
import org.apache.spark.SparkFunSuite
21+
22+
/**
 * Exercises both [[FetchIterator]] implementations with the same scripted sequences of
 * cursor moves, checking the fetch start offset, the position and the rows returned.
 */
class FetchIteratorSuite extends SparkFunSuite {

  /** Reads up to `maxRowCount` rows from `fetchIter`, stopping early when it is exhausted. */
  private def getRows(fetchIter: FetchIterator[Int], maxRowCount: Int): Seq[Int] = {
    val rows = Vector.newBuilder[Int]
    var remaining = maxRowCount
    while (remaining > 0 && fetchIter.hasNext) {
      rows += fetchIter.next()
      remaining -= 1
    }
    rows.result()
  }

  /**
   * Performs `move`, then verifies that a fetch block starts at `expectedStart`, that
   * reading up to `maxRowCount` rows yields `expectedRows`, and that afterwards the fetch
   * start is unchanged while the position has advanced by the number of rows read.
   */
  private def checkFetch(
      fetchIter: FetchIterator[Int],
      expectedStart: Long,
      expectedRows: Seq[Int],
      maxRowCount: Int)(move: => Unit): Unit = {
    move
    assert(fetchIter.getFetchStart == expectedStart)
    assert(fetchIter.getPosition == expectedStart)
    assertResult(expectedRows)(getRows(fetchIter, maxRowCount))
    assert(fetchIter.getFetchStart == expectedStart)
    assert(fetchIter.getPosition == expectedStart + expectedRows.size)
  }

  test("SPARK-33655: Test fetchNext and fetchPrior") {
    val testData = 0 until 10

    def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {
      checkFetch(fetchIter, 0, 0 until 2, 2)(fetchIter.fetchNext())
      checkFetch(fetchIter, 2, 2 until 3, 1)(fetchIter.fetchNext())
      checkFetch(fetchIter, 0, 0 until 3, 3)(fetchIter.fetchPrior(2))
      checkFetch(fetchIter, 3, 3 until 8, 5)(fetchIter.fetchNext())
      checkFetch(fetchIter, 1, 1 until 4, 3)(fetchIter.fetchPrior(2))
      // Asking for more rows than remain returns only what is left.
      checkFetch(fetchIter, 4, 4 until 10, 10)(fetchIter.fetchNext())
      // At the end of the data a further fetch returns nothing.
      checkFetch(fetchIter, 10, Seq.empty[Int], 10)(fetchIter.fetchNext())
      // Moving back further than the beginning lands on position 0.
      checkFetch(fetchIter, 0, 0 until 3, 3)(fetchIter.fetchPrior(20))
    }

    Seq[FetchIterator[Int]](
      new ArrayFetchIterator[Int](testData.toArray),
      new IterableFetchIterator[Int](testData)
    ).foreach(iteratorTest)
  }

  test("SPARK-33655: Test fetchAbsolute") {
    val testData = 0 until 10

    def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {
      checkFetch(fetchIter, 0, 0 until 5, 5)(fetchIter.fetchNext())
      checkFetch(fetchIter, 2, 2 until 5, 3)(fetchIter.fetchAbsolute(2))
      checkFetch(fetchIter, 7, 7 until 8, 1)(fetchIter.fetchAbsolute(7))
      // A target beyond the end of the data stops at the end.
      checkFetch(fetchIter, 10, Seq.empty[Int], 1)(fetchIter.fetchAbsolute(20))
      checkFetch(fetchIter, 0, 0 until 3, 3)(fetchIter.fetchAbsolute(0))
    }

    Seq[FetchIterator[Int]](
      new ArrayFetchIterator[Int](testData.toArray),
      new IterableFetchIterator[Int](testData)
    ).foreach(iteratorTest)
  }
}

0 commit comments

Comments
 (0)