
Commit 455af6b

[KYUUBI #360] correct getNextRowSet with FETCH_PRIOR FETCH_FIRST
1 parent 2307f1f commit 455af6b

File tree: 12 files changed, +311 -14 lines
externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/FetchIterator.scala

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.engine.spark

private[engine] sealed trait FetchIterator[A] extends Iterator[A] {

  /**
   * Begin a fetch block, forward from the current position.
   * Resets the fetch start offset.
   */
  def fetchNext(): Unit

  /**
   * Begin a fetch block, moving the iterator back by `offset` rows from the start of the
   * previous fetch block.
   * Resets the fetch start offset.
   *
   * @param offset the number of rows to move the fetch start position backward.
   */
  def fetchPrior(offset: Long): Unit = fetchAbsolute(getFetchStart - offset)

  /**
   * Begin a fetch block, moving the iterator to the given position.
   * Resets the fetch start offset.
   *
   * @param pos the position to move the iterator to.
   */
  def fetchAbsolute(pos: Long): Unit

  def getFetchStart: Long

  def getPosition: Long
}

private[engine] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A] {
  private var fetchStart: Long = 0

  private var position: Long = 0

  override def fetchNext(): Unit = fetchStart = position

  override def fetchAbsolute(pos: Long): Unit = {
    position = (pos max 0) min src.length
    fetchStart = position
  }

  override def getFetchStart: Long = fetchStart

  override def getPosition: Long = position

  override def hasNext: Boolean = position < src.length

  override def next(): A = {
    position += 1
    src(position.toInt - 1)
  }
}

private[engine] class IterableFetchIterator[A](iterable: Iterable[A]) extends FetchIterator[A] {
  private var iter: Iterator[A] = iterable.iterator

  private var fetchStart: Long = 0

  private var position: Long = 0

  override def fetchNext(): Unit = fetchStart = position

  override def fetchAbsolute(pos: Long): Unit = {
    val newPos = pos max 0
    if (newPos < position) resetPosition()
    while (position < newPos && hasNext) next()
    fetchStart = position
  }

  override def getFetchStart: Long = fetchStart

  override def getPosition: Long = position

  override def hasNext: Boolean = iter.hasNext

  override def next(): A = {
    position += 1
    iter.next()
  }

  private def resetPosition(): Unit = {
    if (position != 0) {
      iter = iterable.iterator
      position = 0
      fetchStart = 0
    }
  }
}
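
As a quick illustration (not part of the commit), the sketch below exercises the fetch-block bookkeeping of ArrayFetchIterator: fetchNext() marks a new block at the current position, fetchPrior(n) jumps back n rows from the start of the last block, and fetchAbsolute(0) rewinds to row zero. The object name is hypothetical; the package declaration is only needed because the iterators are private[engine].

package org.apache.kyuubi.engine.spark

object FetchIteratorExample extends App {
  val it = new ArrayFetchIterator(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))

  it.fetchNext()                        // first block starts at position 0
  assert(it.take(4).toList == List(0, 1, 2, 3))
  assert(it.getPosition == 4 && it.getFetchStart == 0)

  it.fetchNext()                        // second block starts at position 4
  assert(it.take(4).toList == List(4, 5, 6, 7))

  it.fetchPrior(4)                      // 4 rows back from the second block's start
  assert(it.getFetchStart == 0)
  assert(it.take(4).toList == List(0, 1, 2, 3))

  it.fetchAbsolute(0)                   // rewind to row zero
  assert(it.getPosition == 0)
}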

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._

 import org.apache.kyuubi.{KyuubiSQLException, Logging}
+import org.apache.kyuubi.engine.spark.ArrayFetchIterator
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
 import org.apache.kyuubi.operation.{OperationState, OperationType}
 import org.apache.kyuubi.operation.log.OperationLog
@@ -74,7 +75,7 @@ class ExecuteStatement(
       debug(s"original result queryExecution: ${result.queryExecution}")
       val castedResult = result.select(castCols: _*)
       debug(s"casted result queryExecution: ${castedResult.queryExecution}")
-      iter = castedResult.collect().toList.iterator
+      iter = new ArrayFetchIterator(castedResult.collect())
       setState(OperationState.FINISHED)
     } catch {
       onError(cancel = true)

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetCatalogs.scala

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
 import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_CAT
@@ -35,7 +36,7 @@ class GetCatalogs(spark: SparkSession, session: Session)

   override protected def runInternal(): Unit = {
     try {
-      iter = SparkCatalogShim().getCatalogs(spark).toIterator
+      iter = new IterableFetchIterator(SparkCatalogShim().getCatalogs(spark).toList)
     } catch onError()
   }
 }
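
A note on the two implementations: ExecuteStatement above wraps the collected Array in ArrayFetchIterator, which can seek to any position in O(1), while the metadata operations (GetCatalogs here and the Get* operations below) wrap a materialized List in IterableFetchIterator, which can only seek backward by re-creating the iterator and replaying rows. The sketch below (illustrative names, not from the commit) makes that replay visible by counting how often the backing iterable is re-iterated.

package org.apache.kyuubi.engine.spark

object ReplayCostExample extends App {
  var materializations = 0
  // Counts each time a fresh iterator is created over the source data.
  val counted: Iterable[Int] = new Iterable[Int] {
    def iterator: Iterator[Int] = { materializations += 1; (1 to 5).iterator }
  }

  val it = new IterableFetchIterator(counted)  // first materialization
  it.fetchNext()
  it.take(3).toList                            // consume rows 1, 2, 3
  it.fetchAbsolute(0)                          // backward seek replays the source
  println(materializations)                    // prints 2
}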

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetColumns.scala

Lines changed: 3 additions & 3 deletions
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types._

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
 import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -88,9 +89,8 @@ class GetColumns(
       val schemaPattern = toJavaRegex(schemaName)
       val tablePattern = toJavaRegex(tableName)
       val columnPattern = toJavaRegex(columnName)
-      iter = SparkCatalogShim()
-        .getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern)
-        .toList.iterator
+      iter = new IterableFetchIterator(SparkCatalogShim()
+        .getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern).toList)
     } catch {
       onError()
     }

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetFunctions.scala

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@ import java.sql.DatabaseMetaData
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
 import org.apache.kyuubi.session.Session
@@ -70,7 +71,7 @@ class GetFunctions(
             info.getClassName)
         }
       }
-      iter = a.toList.iterator
+      iter = new IterableFetchIterator(a.toList)
     } catch {
       onError()
     }

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetSchemas.scala

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
 import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -42,7 +43,7 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
     try {
       val schemaPattern = toJavaRegex(schema)
       val rows = SparkCatalogShim().getSchemas(spark, catalogName, schemaPattern)
-      iter = rows.toList.toIterator
+      iter = new IterableFetchIterator(rows)
     } catch onError()
   }
 }

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTableTypes.scala

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
 import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -33,6 +34,6 @@ class GetTableTypes(spark: SparkSession, session: Session)
   }

   override protected def runInternal(): Unit = {
-    iter = SparkCatalogShim.sparkTableTypes.map(Row(_)).toList.iterator
+    iter = new IterableFetchIterator(SparkCatalogShim.sparkTableTypes.map(Row(_)).toList)
   }
 }

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTables.scala

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
 import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@@ -73,7 +74,7 @@ class GetTables(
       } else {
         catalogTablesAndViews
       }
-      iter = allTableAndViews.toList.iterator
+      iter = new IterableFetchIterator(allTableAndViews)
     } catch {
       onError()
     }

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTypeInfo.scala

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@ import java.sql.Types._
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.types.StructType

+import org.apache.kyuubi.engine.spark.IterableFetchIterator
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
 import org.apache.kyuubi.session.Session
@@ -83,7 +84,7 @@ class GetTypeInfo(spark: SparkSession, session: Session)
   }

   override protected def runInternal(): Unit = {
-    iter = Seq(
+    iter = new IterableFetchIterator(Seq(
       toRow("VOID", NULL),
       toRow("BOOLEAN", BOOLEAN),
       toRow("TINYINT", TINYINT, 3),
@@ -101,6 +102,6 @@ class GetTypeInfo(spark: SparkSession, session: Session)
       toRow("MAP", JAVA_OBJECT),
       toRow("STRUCT", STRUCT),
       toRow("INTERVAL", OTHER)
-    ).toList.iterator
+    ))
   }
 }

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala

Lines changed: 11 additions & 3 deletions
@@ -25,8 +25,9 @@ import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.types.StructType

 import org.apache.kyuubi.KyuubiSQLException
+import org.apache.kyuubi.engine.spark.FetchIterator
 import org.apache.kyuubi.operation.{AbstractOperation, OperationState}
-import org.apache.kyuubi.operation.FetchOrientation.FetchOrientation
+import org.apache.kyuubi.operation.FetchOrientation._
 import org.apache.kyuubi.operation.OperationState.OperationState
 import org.apache.kyuubi.operation.OperationType.OperationType
 import org.apache.kyuubi.operation.log.OperationLog
@@ -36,7 +37,7 @@ import org.apache.kyuubi.session.Session
 abstract class SparkOperation(spark: SparkSession, opType: OperationType, session: Session)
   extends AbstractOperation(opType, session) {

-  protected var iter: Iterator[Row] = _
+  protected var iter: FetchIterator[Row] = _

   protected final val operationLog: OperationLog =
     OperationLog.createOperationLog(session.handle, getHandle)
@@ -130,8 +131,15 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
     validateDefaultFetchOrientation(order)
     assertState(OperationState.FINISHED)
     setHasResultSet(true)
+    order match {
+      case FETCH_NEXT => iter.fetchNext()
+      case FETCH_PRIOR => iter.fetchPrior(rowSetSize)
+      case FETCH_FIRST => iter.fetchAbsolute(0)
+    }
     val taken = iter.take(rowSetSize)
-    RowSet.toTRowSet(taken.toList, resultSchema, getProtocolVersion)
+    val resultRowSet = RowSet.toTRowSet(taken.toList, resultSchema, getProtocolVersion)
+    resultRowSet.setStartRowOffset(iter.getPosition)
+    resultRowSet
   }

   override def shouldRunAsync: Boolean = false
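
Taken together, getNextRowSet now repositions the iterator before taking rowSetSize rows, and setStartRowOffset reports the iterator offset of the returned block so the client can track where it is in the result set. A rough engine-side simulation of the fetch sequence this enables (object name hypothetical, same package for private[engine] access):

package org.apache.kyuubi.engine.spark

object OrientationExample extends App {
  val rows = new IterableFetchIterator((1 to 10).toList)
  def fetch(n: Int): List[Int] = rows.take(n).toList

  rows.fetchNext();      println(fetch(3))  // List(1, 2, 3) -- FETCH_NEXT
  rows.fetchNext();      println(fetch(3))  // List(4, 5, 6) -- FETCH_NEXT
  rows.fetchPrior(3);    println(fetch(3))  // List(1, 2, 3) -- FETCH_PRIOR re-reads the block before the last one
  rows.fetchAbsolute(0); println(fetch(3))  // List(1, 2, 3) -- FETCH_FIRST restarts at row zero
}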
