diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
index 30bbf8b77b4..af514ceb3c0 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
@@ -39,7 +39,6 @@ import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.{HiveMetadataTests, SparkQueryTests}
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
 import org.apache.kyuubi.util.KyuubiHadoopUtils
-import org.apache.kyuubi.util.SparkVersionUtil.isSparkVersionAtLeast
 
 class SparkOperationSuite extends WithSparkSQLEngine with HiveMetadataTests with SparkQueryTests {
 
@@ -93,12 +92,12 @@ class SparkOperationSuite extends WithSparkSQLEngine with HiveMetadataTests with
       .add("c17", "struct<X: string>", nullable = true, "17")
 
     // since spark3.3.0
-    if (SPARK_ENGINE_VERSION >= "3.3") {
+    if (SPARK_ENGINE_RUNTIME_VERSION >= "3.3") {
       schema = schema.add("c18", "interval day", nullable = true, "18")
         .add("c19", "interval year", nullable = true, "19")
     }
     // since spark3.4.0
-    if (SPARK_ENGINE_VERSION >= "3.4") {
+    if (SPARK_ENGINE_RUNTIME_VERSION >= "3.4") {
       schema = schema.add("c20", "timestamp_ntz", nullable = true, "20")
     }
 
@@ -511,7 +510,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with HiveMetadataTests with
       val status = tOpenSessionResp.getStatus
       val errorMessage = status.getErrorMessage
      assert(status.getStatusCode === TStatusCode.ERROR_STATUS)
-      if (isSparkVersionAtLeast("3.4")) {
+      if (SPARK_ENGINE_RUNTIME_VERSION >= "3.4") {
        assert(errorMessage.contains("[SCHEMA_NOT_FOUND]"))
        assert(errorMessage.contains(s"The schema `$dbName` cannot be found."))
      } else {
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/IcebergMetadataTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/IcebergMetadataTests.scala
index d14224a842f..e3bb4ccb730 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/IcebergMetadataTests.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/IcebergMetadataTests.scala
@@ -17,11 +17,11 @@
 
 package org.apache.kyuubi.operation
 
-import org.apache.kyuubi.IcebergSuiteMixin
+import org.apache.kyuubi.{IcebergSuiteMixin, SPARK_COMPILE_VERSION}
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
-import org.apache.kyuubi.util.SparkVersionUtil.isSparkVersionAtLeast
+import org.apache.kyuubi.util.SparkVersionUtil
 
-trait IcebergMetadataTests extends HiveJDBCTestHelper with IcebergSuiteMixin {
+trait IcebergMetadataTests extends HiveJDBCTestHelper with IcebergSuiteMixin with SparkVersionUtil {
 
   test("get catalogs") {
     withJdbcStatement() { statement =>
@@ -153,11 +153,11 @@ trait IcebergMetadataTests extends HiveJDBCTestHelper with IcebergSuiteMixin {
       "date",
       "timestamp",
       // SPARK-37931
-      if (isSparkVersionAtLeast("3.3")) "struct<X: bigint, Y: double>"
+      if (SPARK_ENGINE_RUNTIME_VERSION >= "3.3") "struct<X: bigint, Y: double>"
       else "struct<`X`: bigint, `Y`: double>",
       "binary",
       // SPARK-37931
-      if (isSparkVersionAtLeast("3.3")) "struct<X: string>" else "struct<`X`: string>")
+      if (SPARK_COMPILE_VERSION >= "3.3") "struct<X: string>" else "struct<`X`: string>")
     val cols = dataTypes.zipWithIndex.map { case (dt, idx) => s"c$idx" -> dt }
     val (colNames, _) = cols.unzip
 
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkDataTypeTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkDataTypeTests.scala
index 3164ae496b3..4c9b2d4f4e3 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkDataTypeTests.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkDataTypeTests.scala
@@ -20,14 +20,15 @@ package org.apache.kyuubi.operation
 import java.sql.{Date, Timestamp}
 
 import org.apache.kyuubi.engine.SemanticVersion
+import org.apache.kyuubi.util.SparkVersionUtil
 
-trait SparkDataTypeTests extends HiveJDBCTestHelper {
-  protected lazy val SPARK_ENGINE_VERSION = sparkEngineMajorMinorVersion
+trait SparkDataTypeTests extends HiveJDBCTestHelper with SparkVersionUtil {
 
   def resultFormat: String = "thrift"
 
   test("execute statement - select null") {
-    assume(resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_VERSION >= "3.2"))
+    assume(
+      resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_RUNTIME_VERSION >= "3.2"))
     withJdbcStatement() { statement =>
       val resultSet = statement.executeQuery("SELECT NULL AS col")
       assert(resultSet.next())
@@ -199,7 +200,7 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
   }
 
   test("execute statement - select timestamp_ntz") {
-    assume(SPARK_ENGINE_VERSION >= "3.4")
+    assume(SPARK_ENGINE_RUNTIME_VERSION >= "3.4")
     withJdbcStatement() { statement =>
       val resultSet = statement.executeQuery(
         "SELECT make_timestamp_ntz(2022, 03, 24, 18, 08, 31.8888) AS col")
@@ -213,7 +214,8 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
   }
 
   test("execute statement - select daytime interval") {
-    assume(resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_VERSION >= "3.3"))
+    assume(
+      resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_RUNTIME_VERSION >= "3.3"))
     withJdbcStatement() { statement =>
       Map(
         "interval 1 day 1 hour -60 minutes 30 seconds" ->
@@ -242,7 +244,7 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
         assert(resultSet.next())
         val result = resultSet.getString("col")
         val metaData = resultSet.getMetaData
-        if (SPARK_ENGINE_VERSION < "3.2") {
+        if (SPARK_ENGINE_RUNTIME_VERSION < "3.2") {
           // for spark 3.1 and backwards
           assert(result === kv._2._2)
           assert(metaData.getPrecision(1) === Int.MaxValue)
@@ -258,7 +260,8 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
   }
 
   test("execute statement - select year/month interval") {
-    assume(resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_VERSION >= "3.3"))
+    assume(
+      resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_RUNTIME_VERSION >= "3.3"))
     withJdbcStatement() { statement =>
       Map(
         "INTERVAL 2022 YEAR" -> Tuple2("2022-0", "2022 years"),
@@ -271,7 +274,7 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
         assert(resultSet.next())
         val result = resultSet.getString("col")
         val metaData = resultSet.getMetaData
-        if (SPARK_ENGINE_VERSION < "3.2") {
+        if (SPARK_ENGINE_RUNTIME_VERSION < "3.2") {
           // for spark 3.1 and backwards
           assert(result === kv._2._2)
           assert(metaData.getPrecision(1) === Int.MaxValue)
@@ -287,7 +290,8 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
   }
 
   test("execute statement - select array") {
-    assume(resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_VERSION >= "3.2"))
+    assume(
+      resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_RUNTIME_VERSION >= "3.2"))
     withJdbcStatement() { statement =>
       val resultSet = statement.executeQuery(
         "SELECT array() AS col1, array(1) AS col2, array(null) AS col3")
@@ -305,7 +309,8 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
   }
 
   test("execute statement - select map") {
-    assume(resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_VERSION >= "3.2"))
+    assume(
+      resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_RUNTIME_VERSION >= "3.2"))
     withJdbcStatement() { statement =>
       val resultSet = statement.executeQuery(
         "SELECT map() AS col1, map(1, 2, 3, 4) AS col2, map(1, null) AS col3")
@@ -323,7 +328,8 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
   }
 
   test("execute statement - select struct") {
-    assume(resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_VERSION >= "3.2"))
+    assume(
+      resultFormat == "thrift" || (resultFormat == "arrow" && SPARK_ENGINE_RUNTIME_VERSION >= "3.2"))
     withJdbcStatement() { statement =>
       val resultSet = statement.executeQuery(
         "SELECT struct('1', '2') AS col1," +
@@ -342,15 +348,4 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
       assert(metaData.getScale(2) == 0)
     }
   }
-
-  def sparkEngineMajorMinorVersion: SemanticVersion = {
-    var sparkRuntimeVer = ""
-    withJdbcStatement() { stmt =>
-      val result = stmt.executeQuery("SELECT version()")
-      assert(result.next())
-      sparkRuntimeVer = result.getString(1)
-      assert(!result.next())
-    }
-    SemanticVersion(sparkRuntimeVer)
-  }
 }
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
index a42b05473a7..ff8b124813c 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
@@ -28,7 +28,6 @@ import org.apache.hive.service.rpc.thrift.{TExecuteStatementReq, TFetchResultsRe
 
 import org.apache.kyuubi.{KYUUBI_VERSION, Utils}
 import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.util.SparkVersionUtil.isSparkVersionAtLeast
 
 trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper {
 
@@ -187,7 +186,7 @@ trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper {
     withJdbcStatement("t") { statement =>
       try {
         val assertTableOrViewNotfound: (Exception, String) => Unit = (e, tableName) => {
-          if (isSparkVersionAtLeast("3.4")) {
+          if (SPARK_ENGINE_RUNTIME_VERSION >= "3.4") {
            assert(e.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND]"))
            assert(e.getMessage.contains(s"The table or view `$tableName` cannot be found."))
          } else {
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/util/SparkVersionUtil.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/util/SparkVersionUtil.scala
index cd8409d10db..785015cc377 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/util/SparkVersionUtil.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/util/SparkVersionUtil.scala
@@ -17,13 +17,22 @@
 
 package org.apache.kyuubi.util
 
-import org.apache.kyuubi.SPARK_COMPILE_VERSION
 import org.apache.kyuubi.engine.SemanticVersion
+import org.apache.kyuubi.operation.HiveJDBCTestHelper
 
-object SparkVersionUtil {
-  lazy val sparkSemanticVersion: SemanticVersion = SemanticVersion(SPARK_COMPILE_VERSION)
+trait SparkVersionUtil {
+  this: HiveJDBCTestHelper =>
 
-  def isSparkVersionAtLeast(ver: String): Boolean = {
-    sparkSemanticVersion.isVersionAtLeast(ver)
+  protected lazy val SPARK_ENGINE_RUNTIME_VERSION = sparkEngineMajorMinorVersion
+
+  def sparkEngineMajorMinorVersion: SemanticVersion = {
+    var sparkRuntimeVer = ""
+    withJdbcStatement() { stmt =>
+      val result = stmt.executeQuery("SELECT version()")
+      assert(result.next())
+      sparkRuntimeVer = result.getString(1)
+      assert(!result.next())
+    }
+    SemanticVersion(sparkRuntimeVer)
   }
 }
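
Note on the version-gating idiom this patch converges on: the tests move from compile-time gating (SPARK_COMPILE_VERSION, the Spark version Kyuubi was built against) to runtime gating (SPARK_ENGINE_RUNTIME_VERSION, obtained by executing "SELECT version()" against the connected engine), so a suite built against one Spark version still skips or adjusts correctly when it talks to a different engine. The sketch below is a minimal, self-contained illustration of that idiom, not Kyuubi's actual org.apache.kyuubi.engine.SemanticVersion: the diff only shows that SemanticVersion is built from a version string and compared with >= and < against "major.minor" literals, so all names, parsing, and error handling here are assumptions for the example.

// VersionGateSketch.scala -- hypothetical stand-in for SemanticVersion.
object VersionGateSketch {

  final case class MajorMinor(major: Int, minor: Int) {
    // Ordered comparison against a "major.minor" literal, e.g. ver >= "3.3".
    def >=(other: String): Boolean = {
      val o = MajorMinor.parse(other)
      major > o.major || (major == o.major && minor >= o.minor)
    }
    def <(other: String): Boolean = !(this >= other)
  }

  object MajorMinor {
    // Accepts full version strings such as "3.4.1" and keeps only "3.4",
    // mirroring how the tests reduce the "SELECT version()" output to major.minor.
    def parse(version: String): MajorMinor = version.split('.').toList match {
      case major :: minor :: _ => MajorMinor(major.toInt, minor.toInt)
      case _ => sys.error(s"unexpected version string: $version")
    }
  }

  def main(args: Array[String]): Unit = {
    // A Spark 3.4 engine reports something like "3.4.1" for SELECT version().
    val runtime = MajorMinor.parse("3.4.1")
    assert(runtime >= "3.3") // interval day/year columns are exercised
    assert(runtime >= "3.4") // timestamp_ntz is exercised
    assert(!(runtime < "3.2")) // not on the legacy spark 3.1 code path
  }
}

One design point grounded in the diff itself: SparkVersionUtil becomes a trait with the self-type "this: HiveJDBCTestHelper =>" rather than an object, because fetching the runtime version needs the suite's own withJdbcStatement helper; and since SPARK_ENGINE_RUNTIME_VERSION is a lazy val, the extra "SELECT version()" round trip happens at most once per suite instance.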