+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.TestSparkSession
+
+/**
+ * A copy of Spark's `TPCDSQueryTestSuite`. We copy it rather than extend it because the methods
+ * we need to modify are private in Spark.
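+ *
+ * The tests are enabled only when the environment variable `SPARK_TPCDS_DATA` points to a
+ * directory containing SF=1 TPC-DS data; the golden result files can be regenerated by setting
+ * `SPARK_GENERATE_GOLDEN_FILES=1`.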
+ */
+class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQueryTestHelper {
+
+  private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA")
+
+  private val regenGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
+
+  // To make output results deterministic
+  override protected def sparkConf: SparkConf = super.sparkConf
+    .set(SQLConf.SHUFFLE_PARTITIONS.key, "1")
+
+  protected override def createSparkSession: TestSparkSession = {
+    new TestSparkSession(new SparkContext("local[1]", this.getClass.getSimpleName, sparkConf))
+  }
+
+  // We use SF=1 table data here, so we cannot use SF=100 stats
+  protected override val injectStats: Boolean = false
+
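+  // Fail early if any TPC-DS table directory is missing under `SPARK_TPCDS_DATA`.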
+  if (tpcdsDataPath.nonEmpty) {
+    val nonExistentTables = tableNames.filterNot { tableName =>
+      Files.exists(Paths.get(s"${tpcdsDataPath.get}/$tableName"))
+    }
+    if (nonExistentTables.nonEmpty) {
+      fail(
+        s"Non-existent TPCDS table paths found in ${tpcdsDataPath.get}: " +
+          nonExistentTables.mkString(", "))
+    }
+  }
+
+  protected val baseResourcePath: String = {
+    // Resolve the resource path the same way `SQLQueryTestSuite` does.
+    getWorkspaceFilePath(
+      "sql",
+      "core",
+      "src",
+      "test",
+      "resources",
+      "tpcds-query-results").toFile.getAbsolutePath
+  }
+
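+  /**
+   * Creates each TPC-DS table as an external table backed by the pre-generated data under
+   * `SPARK_TPCDS_DATA`.
+   */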
+  override def createTable(
+      spark: SparkSession,
+      tableName: String,
+      format: String = "parquet",
+      options: scala.Seq[String]): Unit = {
+    spark.sql(s"""
+                 |CREATE TABLE `$tableName` (${tableColumns(tableName)})
+                 |USING $format
+                 |LOCATION '${tpcdsDataPath.get}/$tableName'
+                 |${options.mkString("\n")}
+       """.stripMargin)
+  }
+
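+  /**
+   * Runs `query` under the given SQL configs and checks its schema and output against the golden
+   * file, regenerating the golden file first when `SPARK_GENERATE_GOLDEN_FILES=1` is set.
+   */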
+  private def runQuery(query: String, goldenFile: File, conf: Map[String, String]): Unit = {
+    // This is `sortMergeJoinConf != conf` in Spark, i.e., Spark only sorts results for joins
+    // other than sort merge join. But for some queries the DataFusion sort produces results that
+    // are correct with respect to the required sort columns yet differ from Spark's in the order
+    // of the remaining columns, so we sort the results for all joins.
+    val shouldSortResults = true
+    withSQLConf(conf.toSeq: _*) {
+      try {
+        val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
+        val queryString = query.trim
+        val outputString = output.mkString("\n").replaceAll("\\s+$", "")
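+        // When regenerating, write the golden file in the same header/schema/output layout that
+        // is parsed back below.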
+        if (regenGoldenFiles) {
+          val goldenOutput = {
+            s"-- Automatically generated by ${getClass.getSimpleName}\n\n" +
+              "-- !query schema\n" +
+              schema + "\n" +
+              "-- !query output\n" +
+              outputString +
+              "\n"
+          }
+          val parent = goldenFile.getParentFile
+          if (!parent.exists()) {
+            assert(parent.mkdirs(), "Could not create directory: " + parent)
+          }
+          stringToFile(goldenFile, goldenOutput)
+        }
+
+        // Read back the golden file.
+        val (expectedSchema, expectedOutput) = {
+          val goldenOutput = fileToString(goldenFile)
+          val segments = goldenOutput.split("-- !query.*\n")
+
+          // Splitting on the `-- !query` markers yields 3 segments: the header, the schema, and
+          // the output.
+          assert(
+            segments.size == 3,
+            s"Expected 3 blocks in result file but got ${segments.size}. " +
+              "Try regenerating the result files.")
+
+          (segments(1).trim, segments(2).replaceAll("\\s+$", ""))
+        }
+
+        val notMatchedSchemaOutput = if (schema == emptySchema) {
+          // There might be an exception. See `handleExceptions`.
+          s"Schema did not match\n$queryString\nOutput/Exception: $outputString"
+        } else {
+          s"Schema did not match\n$queryString"
+        }
+
+        assertResult(expectedSchema, notMatchedSchemaOutput) {
+          schema
+        }
+        if (shouldSortResults) {
+          val expectSorted = expectedOutput
+            .split("\n")
+            .sorted
+            .map(_.trim)
+            .mkString("\n")
+            .replaceAll("\\s+$", "")
+          val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "")
+          assertResult(expectSorted, s"Result did not match\n$queryString") {
+            outputSorted
+          }
+        } else {
+          assertResult(expectedOutput, s"Result did not match\n$queryString") {
+            outputString
+          }
+        }
+      } catch {
+        case e: Throwable =>
+          val configs = conf.map { case (k, v) =>
+            s"$k=$v"
+          }
+          throw new Exception(s"${e.getMessage}\nError using configs:\n${configs.mkString("\n")}")
+      }
+    }
+  }
+
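+  // SQL configs that steer the planner towards a specific join strategy.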
+  val sortMergeJoinConf: Map[String, String] = Map(
+    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+    SQLConf.PREFER_SORTMERGEJOIN.key -> "true")
+
+  val broadcastHashJoinConf: Map[String, String] = Map(
+    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760")
+
+  val shuffledHashJoinConf: Map[String, String] = Map(
+    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+    "spark.sql.join.forceApplyShuffledHashJoin" -> "true")
+
+  val allJoinConfCombinations: Seq[Map[String, String]] =
+    Seq(sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf)
+
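+  // Golden files are always generated with sort merge join; otherwise the environment variable
+  // `SPARK_TPCDS_JOIN_CONF` (a properties-style string) can narrow down which join
+  // configurations are run.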
+  val joinConfs: Seq[Map[String, String]] = if (regenGoldenFiles) {
+    require(
+      !sys.env.contains("SPARK_TPCDS_JOIN_CONF"),
+      "'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'")
+    Seq(sortMergeJoinConf)
+  } else {
+    sys.env
+      .get("SPARK_TPCDS_JOIN_CONF")
+      .map { s =>
+        val p = new java.util.Properties()
+        p.load(new java.io.StringReader(s))
+        Seq(p.asScala.toMap)
+      }
+      .getOrElse(allJoinConfCombinations)
+  }
+
+  assert(joinConfs.nonEmpty)
+  joinConfs.foreach(conf =>
+    require(
+      allJoinConfCombinations.contains(conf),
+      s"Join configuration [$conf] should be one of $allJoinConfCombinations"))
+
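+  // Register one test per TPC-DS query; each test replays the query under every selected join
+  // configuration against the same golden file.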
+  if (tpcdsDataPath.nonEmpty) {
+    tpcdsQueries.foreach { name =>
+      val queryString = resourceToString(
+        s"tpcds/$name.sql",
+        classLoader = Thread.currentThread().getContextClassLoader)
+      test(name) {
+        val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out")
+        joinConfs.foreach { conf =>
+          System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368
+          runQuery(queryString, goldenFile, conf)
+        }
+      }
+    }
+
+    tpcdsQueriesV2_7_0.foreach { name =>
+      val queryString = resourceToString(
+        s"tpcds-v2.7.0/$name.sql",
+        classLoader = Thread.currentThread().getContextClassLoader)
+      test(s"$name-v2.7") {
+        val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out")
+        joinConfs.foreach { conf =>
+          System.gc() // SPARK-37368
+          runQuery(queryString, goldenFile, conf)
+        }
+      }
+    }
+  } else {
+    ignore("skipped because env `SPARK_TPCDS_DATA` is not set") {}
+  }
+}