diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 65a769da70aea..164bfd42d6e4a 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -58,6 +58,8 @@ license: | - In Spark 3.1, creating or altering a view will capture runtime SQL configs and store them as view properties. These configs will be applied during the parsing and analysis phases of the view resolution. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.useCurrentConfigsForView` to `true`. + - Since Spark 3.1, CHAR/CHARACTER and VARCHAR types are supported in the table schema. Table scan/insertion will respect the char/varchar semantic. If char/varchar is used in places other than table schema, an exception will be thrown (CAST is an exception that simply treats char/varchar as string like before). To restore the behavior before Spark 3.1, which treats them as STRING types and ignores a length parameter, e.g. `CHAR(4)`, you can set `spark.sql.legacy.charVarcharAsString` to `true`. + ## Upgrading from Spark SQL 3.0 to 3.0.1 - In Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match to the pattern defined by the JSON option `timestampFormat`. Since version 3.0.1, the timestamp type inference is disabled by default. Set the JSON option `inferTimestamp` to `true` to enable such type inference. diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index c3e17dc22eed0..08ba07aa8de63 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -144,14 +144,18 @@ SELECT * FROM t; The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`). - `size`: This function returns null for null input. - - `element_at`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. - - `element_at`: This function throws `NoSuchElementException` if key does not exist in map. + - `element_at`: + - This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. + - This function throws `NoSuchElementException` if key does not exist in map. - `elt`: This function throws `ArrayIndexOutOfBoundsException` if using invalid indices. - `parse_url`: This function throws `IllegalArgumentException` if an input string is not a valid url. - - `to_date` This function should fail with an exception if the input string can't be parsed, or the pattern string is invalid. - - `to_timestamp` This function should fail with an exception if the input string can't be parsed, or the pattern string is invalid. - - `unix_timestamp` This function should fail with an exception if the input string can't be parsed, or the pattern string is invalid. - - `to_unix_timestamp` This function should fail with an exception if the input string can't be parsed, or the pattern string is invalid. + - `to_date`: This function should fail with an exception if the input string can't be parsed, or the pattern string is invalid. + - `to_timestamp`: This function should fail with an exception if the input string can't be parsed, or the pattern string is invalid. + - `unix_timestamp`: This function should fail with an exception if the input string can't be parsed, or the pattern string is invalid. + - `to_unix_timestamp`: This function should fail with an exception if the input string can't be parsed, or the pattern string is invalid. + - `make_date`: This function should fail with an exception if the result date is invalid. 
+ - `make_timestamp`: This function should fail with an exception if the result timestamp is invalid. + - `make_interval`: This function should fail with an exception if the result interval is invalid. ### SQL Operators diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7f791e02a392b..618faef2d58b3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1069,7 +1069,7 @@ private[spark] class Client( logError(s"Application $appId not found.") cleanupStagingDir() return YarnAppReport(YarnApplicationState.KILLED, FinalApplicationStatus.KILLED, None) - case NonFatal(e) => + case NonFatal(e) if !e.isInstanceOf[InterruptedIOException] => val msg = s"Failed to contact YARN for application $appId." logError(msg, e) // Don't necessarily clean up staging dir because status is unknown diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index cb0de5a0d50b4..8a55e612ce719 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster +import java.io.InterruptedIOException + import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState @@ -121,7 +123,8 @@ private[spark] class YarnClientSchedulerBackend( allowInterrupt = false sc.stop() } catch { - case e: InterruptedException => logInfo("Interrupting monitor thread") + case _: InterruptedException | _: InterruptedIOException => + logInfo("Interrupting monitor thread") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala index ad6cf959a69c6..1f3f762662252 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala @@ -90,7 +90,7 @@ trait AliasHelper { exprId = a.exprId, qualifier = a.qualifier, explicitMetadata = Some(a.metadata), - deniedMetadataKeys = a.deniedMetadataKeys) + nonInheritableMetadataKeys = a.nonInheritableMetadataKeys) case a: MultiAlias => a.copy(child = trimAliases(a.child)) case other => trimAliases(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 22aabd3c6b30b..badc2ecc9cb28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -143,14 +143,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * fully qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. 
- * @param deniedMetadataKeys Keys of metadata entries that are supposed to be removed when - * inheriting the metadata from the child. + * @param nonInheritableMetadataKeys Keys of metadata entries that are supposed to be removed when + * inheriting the metadata from the child. */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifier: Seq[String] = Seq.empty, val explicitMetadata: Option[Metadata] = None, - val deniedMetadataKeys: Seq[String] = Seq.empty) + val nonInheritableMetadataKeys: Seq[String] = Seq.empty) extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) @@ -172,7 +172,7 @@ case class Alias(child: Expression, name: String)( child match { case named: NamedExpression => val builder = new MetadataBuilder().withMetadata(named.metadata) - deniedMetadataKeys.foreach(builder.remove) + nonInheritableMetadataKeys.foreach(builder.remove) builder.build() case _ => Metadata.empty @@ -181,7 +181,10 @@ case class Alias(child: Expression, name: String)( } def newInstance(): NamedExpression = - Alias(child, name)(qualifier = qualifier, explicitMetadata = explicitMetadata) + Alias(child, name)( + qualifier = qualifier, + explicitMetadata = explicitMetadata, + nonInheritableMetadataKeys = nonInheritableMetadataKeys) override def toAttribute: Attribute = { if (resolved) { @@ -201,7 +204,7 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifier :: explicitMetadata :: deniedMetadataKeys :: Nil + exprId :: qualifier :: explicitMetadata :: nonInheritableMetadataKeys :: Nil } override def hashCode(): Int = { @@ -212,7 +215,8 @@ case class Alias(child: Expression, name: String)( override def equals(other: Any): Boolean = other match { case a: Alias => name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier && - explicitMetadata == a.explicitMetadata && deniedMetadataKeys == a.deniedMetadataKeys + explicitMetadata == a.explicitMetadata && + nonInheritableMetadataKeys == a.nonInheritableMetadataKeys case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index ef3de4738c75c..698ece4f9e69f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If} import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or} import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.BooleanType import org.apache.spark.util.Utils @@ -53,6 +53,7 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ 
Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) case p: LogicalPlan => p transformExpressions { case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) case cw @ CaseWhen(branches, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index eb52c5b74772c..6fc31c94e47eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} @@ -48,6 +48,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { test("replace null inside filter and join conditions") { testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) } test("Not expected type - replaceNullWithFalse") { @@ -64,6 +65,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(null, BooleanType)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace nulls in nested expressions in branches of If") { @@ -73,6 +75,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { UnresolvedAttribute("b") && Literal(null, BooleanType)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in elseValue of CaseWhen") { @@ -83,6 +86,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedCond = CaseWhen(branches, FalseLiteral) testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) + testDelete(originalCond, expectedCond) } test("replace null in branch values of CaseWhen") { @@ -92,6 +96,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val originalCond = CaseWhen(branches, Literal(null)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside CaseWhen") { @@ -108,6 +113,7 @@ class ReplaceNullWithFalseInPredicateSuite 
extends PlanTest { testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) + testDelete(originalCond, expectedCond) } test("replace null in complex CaseWhen expressions") { @@ -127,6 +133,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) + testDelete(originalCond, expectedCond) } test("replace null in Or") { @@ -134,12 +141,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val expectedCond = UnresolvedAttribute("b") testFilter(originalCond, expectedCond) testJoin(originalCond, expectedCond) + testDelete(originalCond, expectedCond) } test("replace null in And") { val originalCond = And(UnresolvedAttribute("b"), Literal(null)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace nulls in nested And/Or expressions") { @@ -148,6 +157,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null))))) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in And inside branches of If") { @@ -157,6 +167,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { And(UnresolvedAttribute("b"), Literal(null, BooleanType))) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside And") { @@ -168,6 +179,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { And(FalseLiteral, UnresolvedAttribute("b")))) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in branches of If inside another If") { @@ -177,6 +189,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(null)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("replace null in CaseWhen inside another CaseWhen") { @@ -184,6 +197,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null)) testFilter(originalCond, expectedCond = FalseLiteral) testJoin(originalCond, expectedCond = FalseLiteral) + testDelete(originalCond, expectedCond = FalseLiteral) } test("inability to replace null in non-boolean branches of If") { @@ -196,6 +210,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) + testDelete(originalCond = condition, expectedCond = condition) } test("inability to replace null in non-boolean values of CaseWhen") { @@ -210,6 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val condition = CaseWhen(branches) testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) + testDelete(originalCond = condition, expectedCond = condition) } test("inability to replace null in non-boolean branches of If inside another If") { @@ -222,6 +238,7 @@ 
class ReplaceNullWithFalseInPredicateSuite extends PlanTest { FalseLiteral) testFilter(originalCond = condition, expectedCond = condition) testJoin(originalCond = condition, expectedCond = condition) + testDelete(originalCond = condition, expectedCond = condition) } test("replace null in If used as a join condition") { @@ -353,6 +370,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) } + private def testDelete(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, expr) => DeleteFromTable(rel, Some(expr)), originalCond, expectedCond) + } + private def testHigherOrderFunc( argument: Expression, createExpr: (Expression, Expression) => Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4ef23d7e31c59..539ef8dfe2665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1164,10 +1164,11 @@ class Column(val expr: Expression) extends Logging { * @since 2.0.0 */ def name(alias: String): Column = withExpr { - // SPARK-33536: The Alias is no longer a column reference after converting to an attribute. - // These denied metadata keys are used to strip the column reference related metadata for - // the Alias. So it won't be caught as a column reference in DetectAmbiguousSelfJoin. - Alias(expr, alias)(deniedMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) + // SPARK-33536: an alias is no longer a column reference. Therefore, + // we should not inherit the column reference related metadata in an alias + // so that it is not caught as a column reference in DetectAmbiguousSelfJoin. 
+ Alias(expr, alias)( + nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 5e1c6ba92803d..7c19f98b762f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -577,8 +577,8 @@ class ColumnarAlias(child: ColumnarExpression, name: String)( override val exprId: ExprId = NamedExpression.newExprId, override val qualifier: Seq[String] = Seq.empty, override val explicitMetadata: Option[Metadata] = None, - override val deniedMetadataKeys: Seq[String] = Seq.empty) - extends Alias(child, name)(exprId, qualifier, explicitMetadata, deniedMetadataKeys) + override val nonInheritableMetadataKeys: Seq[String] = Seq.empty) + extends Alias(child, name)(exprId, qualifier, explicitMetadata, nonInheritableMetadataKeys) with ColumnarExpression { override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch) @@ -715,7 +715,7 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] { def replaceWithColumnarExpression(exp: Expression): ColumnarExpression = exp match { case a: Alias => new ColumnarAlias(replaceWithColumnarExpression(a.child), - a.name)(a.exprId, a.qualifier, a.explicitMetadata, a.deniedMetadataKeys) + a.name)(a.exprId, a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys) case att: AttributeReference => new ColumnarAttributeReference(att.name, att.dataType, att.nullable, att.metadata)(att.exprId, att.qualifier) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala new file mode 100644 index 0000000000000..b9db657952b56 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +private[hive] sealed trait FetchIterator[A] extends Iterator[A] { + /** + * Begin a fetch block, forward from the current position. + * Resets the fetch start offset. + */ + def fetchNext(): Unit + + /** + * Begin a fetch block, moving the iterator back by offset from the start of the previous fetch + * block start. + * Resets the fetch start offset. + * + * @param offset the amount to move a fetch start position toward the prior direction. 
+ */ + def fetchPrior(offset: Long): Unit = fetchAbsolute(getFetchStart - offset) + + /** + * Begin a fetch block, moving the iterator to the given position. + * Resets the fetch start offset. + * + * @param pos index to move a position of iterator. + */ + def fetchAbsolute(pos: Long): Unit + + def getFetchStart: Long + + def getPosition: Long +} + +private[hive] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A] { + private var fetchStart: Long = 0 + + private var position: Long = 0 + + override def fetchNext(): Unit = fetchStart = position + + override def fetchAbsolute(pos: Long): Unit = { + position = (pos max 0) min src.length + fetchStart = position + } + + override def getFetchStart: Long = fetchStart + + override def getPosition: Long = position + + override def hasNext: Boolean = position < src.length + + override def next(): A = { + position += 1 + src(position.toInt - 1) + } +} + +private[hive] class IterableFetchIterator[A](iterable: Iterable[A]) extends FetchIterator[A] { + private var iter: Iterator[A] = iterable.iterator + + private var fetchStart: Long = 0 + + private var position: Long = 0 + + override def fetchNext(): Unit = fetchStart = position + + override def fetchAbsolute(pos: Long): Unit = { + val newPos = pos max 0 + if (newPos < position) resetPosition() + while (position < newPos && hasNext) next() + fetchStart = position + } + + override def getFetchStart: Long = fetchStart + + override def getPosition: Long = position + + override def hasNext: Boolean = iter.hasNext + + override def next(): A = { + position += 1 + iter.next() + } + + private def resetPosition(): Unit = { + if (position != 0) { + iter = iterable.iterator + position = 0 + fetchStart = 0 + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index f7a4be9591818..c4ae035e1f836 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -69,13 +69,7 @@ private[hive] class SparkExecuteStatementOperation( private var result: DataFrame = _ - // We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST. - // This is only used when `spark.sql.thriftServer.incrementalCollect` is set to `false`. - // In case of `true`, this will be `None` and FETCH_FIRST will trigger re-execution. - private var resultList: Option[Array[SparkRow]] = _ - private var previousFetchEndOffset: Long = 0 - private var previousFetchStartOffset: Long = 0 - private var iter: Iterator[SparkRow] = _ + private var iter: FetchIterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ private lazy val resultSchema: TableSchema = { @@ -148,43 +142,14 @@ private[hive] class SparkExecuteStatementOperation( setHasResultSet(true) val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false) - // Reset iter when FETCH_FIRST or FETCH_PRIOR - if ((order.equals(FetchOrientation.FETCH_FIRST) || - order.equals(FetchOrientation.FETCH_PRIOR)) && previousFetchEndOffset != 0) { - // Reset the iterator to the beginning of the query. 
- iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) { - resultList = None - result.toLocalIterator.asScala - } else { - if (resultList.isEmpty) { - resultList = Some(result.collect()) - } - resultList.get.iterator - } - } - - var resultOffset = { - if (order.equals(FetchOrientation.FETCH_FIRST)) { - logInfo(s"FETCH_FIRST request with $statementId. Resetting to resultOffset=0") - 0 - } else if (order.equals(FetchOrientation.FETCH_PRIOR)) { - // TODO: FETCH_PRIOR should be handled more efficiently than rewinding to beginning and - // reiterating. - val targetOffset = math.max(previousFetchStartOffset - maxRowsL, 0) - logInfo(s"FETCH_PRIOR request with $statementId. Resetting to resultOffset=$targetOffset") - var off = 0 - while (off < targetOffset && iter.hasNext) { - iter.next() - off += 1 - } - off - } else { // FETCH_NEXT - previousFetchEndOffset - } + if (order.equals(FetchOrientation.FETCH_FIRST)) { + iter.fetchAbsolute(0) + } else if (order.equals(FetchOrientation.FETCH_PRIOR)) { + iter.fetchPrior(maxRowsL) + } else { + iter.fetchNext() } - - resultRowSet.setStartOffset(resultOffset) - previousFetchStartOffset = resultOffset + resultRowSet.setStartOffset(iter.getPosition) if (!iter.hasNext) { resultRowSet } else { @@ -206,11 +171,9 @@ private[hive] class SparkExecuteStatementOperation( } resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) curRow += 1 - resultOffset += 1 } - previousFetchEndOffset = resultOffset log.info(s"Returning result set with ${curRow} rows from offsets " + - s"[$previousFetchStartOffset, $previousFetchEndOffset) with $statementId") + s"[${iter.getFetchStart}, ${iter.getPosition}) with $statementId") resultRowSet } } @@ -326,14 +289,12 @@ private[hive] class SparkExecuteStatementOperation( logDebug(result.queryExecution.toString()) HiveThriftServer2.eventManager.onStatementParsed(statementId, result.queryExecution.toString()) - iter = { - if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) { - resultList = None - result.toLocalIterator.asScala - } else { - resultList = Some(result.collect()) - resultList.get.iterator - } + iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) { + new IterableFetchIterator[SparkRow](new Iterable[SparkRow] { + override def iterator: Iterator[SparkRow] = result.toLocalIterator.asScala + }) + } else { + new ArrayFetchIterator[SparkRow](result.collect()) } dataTypes = result.schema.fields.map(_.dataType) } catch { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala new file mode 100644 index 0000000000000..0fbdb8a9050c8 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.spark.SparkFunSuite + +class FetchIteratorSuite extends SparkFunSuite { + + private def getRows(fetchIter: FetchIterator[Int], maxRowCount: Int): Seq[Int] = { + for (_ <- 0 until maxRowCount if fetchIter.hasNext) yield fetchIter.next() + } + + test("SPARK-33655: Test fetchNext and fetchPrior") { + val testData = 0 until 10 + + def iteratorTest(fetchIter: FetchIterator[Int]): Unit = { + fetchIter.fetchNext() + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 0) + assertResult(0 until 2)(getRows(fetchIter, 2)) + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 2) + + fetchIter.fetchNext() + assert(fetchIter.getFetchStart == 2) + assert(fetchIter.getPosition == 2) + assertResult(2 until 3)(getRows(fetchIter, 1)) + assert(fetchIter.getFetchStart == 2) + assert(fetchIter.getPosition == 3) + + fetchIter.fetchPrior(2) + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 0) + assertResult(0 until 3)(getRows(fetchIter, 3)) + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 3) + + fetchIter.fetchNext() + assert(fetchIter.getFetchStart == 3) + assert(fetchIter.getPosition == 3) + assertResult(3 until 8)(getRows(fetchIter, 5)) + assert(fetchIter.getFetchStart == 3) + assert(fetchIter.getPosition == 8) + + fetchIter.fetchPrior(2) + assert(fetchIter.getFetchStart == 1) + assert(fetchIter.getPosition == 1) + assertResult(1 until 4)(getRows(fetchIter, 3)) + assert(fetchIter.getFetchStart == 1) + assert(fetchIter.getPosition == 4) + + fetchIter.fetchNext() + assert(fetchIter.getFetchStart == 4) + assert(fetchIter.getPosition == 4) + assertResult(4 until 10)(getRows(fetchIter, 10)) + assert(fetchIter.getFetchStart == 4) + assert(fetchIter.getPosition == 10) + + fetchIter.fetchNext() + assert(fetchIter.getFetchStart == 10) + assert(fetchIter.getPosition == 10) + assertResult(Seq.empty[Int])(getRows(fetchIter, 10)) + assert(fetchIter.getFetchStart == 10) + assert(fetchIter.getPosition == 10) + + fetchIter.fetchPrior(20) + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 0) + assertResult(0 until 3)(getRows(fetchIter, 3)) + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 3) + } + iteratorTest(new ArrayFetchIterator[Int](testData.toArray)) + iteratorTest(new IterableFetchIterator[Int](testData)) + } + + test("SPARK-33655: Test fetchAbsolute") { + val testData = 0 until 10 + + def iteratorTest(fetchIter: FetchIterator[Int]): Unit = { + fetchIter.fetchNext() + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 0) + assertResult(0 until 5)(getRows(fetchIter, 5)) + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 5) + + fetchIter.fetchAbsolute(2) + assert(fetchIter.getFetchStart == 2) + assert(fetchIter.getPosition == 2) + assertResult(2 until 5)(getRows(fetchIter, 3)) + assert(fetchIter.getFetchStart == 2) + assert(fetchIter.getPosition == 5) + + fetchIter.fetchAbsolute(7) + assert(fetchIter.getFetchStart == 7) + assert(fetchIter.getPosition == 7) + 
assertResult(7 until 8)(getRows(fetchIter, 1)) + assert(fetchIter.getFetchStart == 7) + assert(fetchIter.getPosition == 8) + + fetchIter.fetchAbsolute(20) + assert(fetchIter.getFetchStart == 10) + assert(fetchIter.getPosition == 10) + assertResult(Seq.empty[Int])(getRows(fetchIter, 1)) + assert(fetchIter.getFetchStart == 10) + assert(fetchIter.getPosition == 10) + + fetchIter.fetchAbsolute(0) + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 0) + assertResult(0 until 3)(getRows(fetchIter, 3)) + assert(fetchIter.getFetchStart == 0) + assert(fetchIter.getPosition == 3) + } + iteratorTest(new ArrayFetchIterator[Int](testData.toArray)) + iteratorTest(new IterableFetchIterator[Int](testData)) + } +}
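The migration-guide hunk above introduces length enforcement for CHAR/VARCHAR in table schemas. A REPL-style sketch (e.g. spark-shell on a 3.1 build) of how the change might be observed; the table name `demo_char` is hypothetical and the padding/failure behavior is stated as expected, not taken verbatim from the patch:

```scala
// REPL-style sketch; `spark` is the SparkSession predefined by spark-shell.
spark.sql("CREATE TABLE demo_char (c CHAR(4)) USING parquet")
spark.sql("INSERT INTO demo_char VALUES ('ab')")
// CHAR values are padded to the declared length on read, so length(c) is expected to be 4.
spark.sql("SELECT c, length(c) FROM demo_char").show()
// A value longer than the declared length is expected to fail at insert time:
// spark.sql("INSERT INTO demo_char VALUES ('abcde')")
// Restore the pre-3.1 behavior (CHAR(4)/VARCHAR(n) treated as plain STRING):
spark.conf.set("spark.sql.legacy.charVarcharAsString", "true")
```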
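The ANSI-compliance hunk lists functions whose failure behavior changes under `spark.sql.ansi.enabled=true`. A minimal sketch of those cases, again assuming a spark-shell session; the queries are illustrative and kept commented out because each one is expected to throw:

```scala
// Sketch of the ANSI-mode behavior described in the doc hunk above.
spark.conf.set("spark.sql.ansi.enabled", "true")

// These fail with an exception instead of returning NULL:
// spark.sql("SELECT element_at(array(1, 2, 3), 5)").show()       // ArrayIndexOutOfBoundsException
// spark.sql("SELECT element_at(map('k', 1), 'missing')").show()  // NoSuchElementException
// spark.sql("SELECT make_date(2021, 2, 30)").show()              // invalid result date
// spark.sql("SELECT to_date('oops', 'yyyy-MM-dd')").show()       // unparsable input

spark.conf.set("spark.sql.ansi.enabled", "false")  // back to the default: the calls above return NULL
```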
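The optimizer change extends `ReplaceNullWithFalseInPredicate` to `DeleteFromTable` conditions. A small sketch mirroring the new `testDelete` helper, assuming the catalyst DSL used by the suite; the relation and attribute names are placeholders:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation}
import org.apache.spark.sql.types.BooleanType

val rel = LocalRelation('a.int, 'b.boolean)
// A DELETE whose condition is a bare NULL literal...
val delete = DeleteFromTable(rel, Some(Literal(null, BooleanType)))
// ...should now have its condition rewritten to FalseLiteral, as the new test cases assert.
val optimized = ReplaceNullWithFalseInPredicate(delete)
```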
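The `Alias` change renames `deniedMetadataKeys` to `nonInheritableMetadataKeys` and makes `Column.name` pass the dataset-id keys so they are not inherited. A sketch of the renamed parameter in isolation; the metadata key here is a hypothetical stand-in for `Dataset.DATASET_ID_KEY`, not the actual constant:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder}

// Hypothetical metadata key, standing in for the dataset-id metadata in the patch.
val md = new MetadataBuilder().putLong("__example_dataset_id", 42L).build()
val child = AttributeReference("id", IntegerType, nullable = false, md)()

// Keys listed as non-inheritable are stripped rather than copied from the child.
val aliased = Alias(child, "renamed")(nonInheritableMetadataKeys = Seq("__example_dataset_id"))
assert(!aliased.metadata.contains("__example_dataset_id"))
```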
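The new `FetchIterator` centralizes the FETCH_NEXT/FETCH_PRIOR/FETCH_FIRST bookkeeping that `SparkExecuteStatementOperation` previously did by hand. A usage sketch mirroring `FetchIteratorSuite`; the iterators are `private[hive]`, so this is assumed to live in the `org.apache.spark.sql.hive.thriftserver` package (for example as a test), and the object name is hypothetical:

```scala
package org.apache.spark.sql.hive.thriftserver

object FetchIteratorSketch {
  def main(args: Array[String]): Unit = {
    val iter = new ArrayFetchIterator[Int]((0 until 10).toArray)

    iter.fetchNext()                        // open a fetch block at position 0
    (0 until 5).foreach(_ => iter.next())   // consume rows 0..4; position is now 5

    iter.fetchPrior(2)                      // rewind to 2 rows before the previous fetch start, clamped at 0
    assert(iter.getFetchStart == 0 && iter.getPosition == 0)

    iter.fetchAbsolute(7)                   // FETCH_FIRST in the operation above maps to fetchAbsolute(0)
    assert(iter.getFetchStart == 7 && iter.getPosition == 7)
  }
}
```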