
Commit 335146e

test: Copy Spark TPCDSQueryTestSuite to CometTPCDSQueryTestSuite (apache#628)
1 parent eff2897 commit 335146e

3 files changed: +339 -1 lines changed
CometSQLQueryTestHelper.scala (new file)

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql

import scala.util.control.NonFatal

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase}
import org.apache.spark.sql.types.StructType

trait CometSQLQueryTestHelper {

  private val notIncludedMsg = "[not included in comparison]"
  private val clsName = this.getClass.getCanonicalName
  protected val emptySchema: String = StructType(Seq.empty).catalogString

  protected def replaceNotIncludedMsg(line: String): String = {
    line
      .replaceAll("#\\d+", "#x")
      .replaceAll("plan_id=\\d+", "plan_id=x")
      .replaceAll(s"Location.*$clsName/", s"Location $notIncludedMsg/{warehouse_dir}/")
      .replaceAll(s"file:[^\\s,]*$clsName", s"file:$notIncludedMsg/{warehouse_dir}")
      .replaceAll("Created By.*", s"Created By $notIncludedMsg")
      .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
      .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
      .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg")
      .replaceAll("\\*\\(\\d+\\) ", "*") // remove the WholeStageCodegen codegenStageIds
  }

  /** Executes a query and returns the result as (schema of the output, normalized output). */
  protected def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
    // Returns true if the plan is supposed to be sorted.
    def isSorted(plan: LogicalPlan): Boolean = plan match {
      case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
      case _: DescribeCommandBase | _: DescribeColumnCommand | _: DescribeRelation |
          _: DescribeColumn =>
        true
      case PhysicalOperation(_, _, Sort(_, true, _)) => true
      case _ => plan.children.iterator.exists(isSorted)
    }

    val df = session.sql(sql)
    val schema = df.schema.catalogString
    // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
    val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) {
      hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
    }

    // If the output is not pre-sorted, sort it.
    if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
  }

  /**
   * This method handles exceptions that occur during query execution, as they may need special
   * care to become comparable to the expected output.
   *
   * @param result
   *   a function that returns a pair of schema and output
   */
  protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
    try {
      result
    } catch {
      case e: SparkThrowable with Throwable if e.getErrorClass != null =>
        (emptySchema, Seq(e.getClass.getName, e.getMessage))
      case a: AnalysisException =>
        // Do not output the logical plan tree which contains expression IDs.
        // Also implement a crude way of masking expression IDs in the error message
        // with a generic pattern "###".
        val msg = a.getMessage
        (emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
      case s: SparkException if s.getCause != null =>
        // For a runtime exception, it is hard to match because its message contains
        // information of stage, task ID, etc.
        // To make result matching simpler, here we match the cause of the exception if it exists.
        s.getCause match {
          case e: SparkThrowable with Throwable if e.getErrorClass != null =>
            (emptySchema, Seq(e.getClass.getName, e.getMessage))
          case cause =>
            (emptySchema, Seq(cause.getClass.getName, cause.getMessage))
        }
      case NonFatal(e) =>
        // If there is an exception, put the exception class followed by the message.
        (emptySchema, Seq(e.getClass.getName, e.getMessage))
    }
  }
}
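
A minimal standalone sketch (not part of the commit) of the expression-ID masking that replaceNotIncludedMsg performs; the explain-plan fragment and the object name NormalizeExample are made up for illustration, and only the first two regex rewrites are shown:

object NormalizeExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical explain-plan line containing expression IDs and a plan_id tag.
    val line = "Project [id#123, name#124], plan_id=42"
    // The same masking used above: replace #<digits> and plan_id=<digits> with stable tokens
    // so golden files do not depend on non-deterministic IDs.
    val normalized = line
      .replaceAll("#\\d+", "#x")
      .replaceAll("plan_id=\\d+", "plan_id=x")
    println(normalized) // Project [id#x, name#x], plan_id=x
  }
}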

spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ class CometTPCDSQuerySuite
     override val tpcdsQueries: Seq[String] =
       tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains)
   }
-  with TPCDSQueryTestSuite
+  with CometTPCDSQueryTestSuite
   with ShimCometTPCDSQuerySuite {
   override def sparkConf: SparkConf = {
     val conf = super.sparkConf
CometTPCDSQueryTestSuite.scala (new file)

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql

import java.io.File
import java.nio.file.{Files, Paths}

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.TestSparkSession

/**
 * Because we need to modify some methods of Spark `TPCDSQueryTestSuite` but they are private, we
 * copy Spark `TPCDSQueryTestSuite`.
 */
class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQueryTestHelper {

  private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA")

  private val regenGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"

  // To make output results deterministic
  override protected def sparkConf: SparkConf = super.sparkConf
    .set(SQLConf.SHUFFLE_PARTITIONS.key, "1")

  protected override def createSparkSession: TestSparkSession = {
    new TestSparkSession(new SparkContext("local[1]", this.getClass.getSimpleName, sparkConf))
  }

  // We use SF=1 table data here, so we cannot use SF=100 stats
  protected override val injectStats: Boolean = false

  if (tpcdsDataPath.nonEmpty) {
    val nonExistentTables = tableNames.filterNot { tableName =>
      Files.exists(Paths.get(s"${tpcdsDataPath.get}/$tableName"))
    }
    if (nonExistentTables.nonEmpty) {
      fail(
        s"Non-existent TPCDS table paths found in ${tpcdsDataPath.get}: " +
          nonExistentTables.mkString(", "))
    }
  }

  protected val baseResourcePath: String = {
    // use the same way as `SQLQueryTestSuite` to get the resource path
    getWorkspaceFilePath(
      "sql",
      "core",
      "src",
      "test",
      "resources",
      "tpcds-query-results").toFile.getAbsolutePath
  }

  override def createTable(
      spark: SparkSession,
      tableName: String,
      format: String = "parquet",
      options: scala.Seq[String]): Unit = {
    spark.sql(s"""
                 |CREATE TABLE `$tableName` (${tableColumns(tableName)})
                 |USING $format
                 |LOCATION '${tpcdsDataPath.get}/$tableName'
                 |${options.mkString("\n")}
               """.stripMargin)
  }

  private def runQuery(query: String, goldenFile: File, conf: Map[String, String]): Unit = {
    // This is `sortMergeJoinConf != conf` in Spark, i.e., Spark sorts results for joins other
    // than sort merge join. But in some queries the DataFusion sort returns results that are
    // correct with respect to the required sorting columns yet differ from Spark in the order
    // of irrelevant columns. So we need to sort the results for all joins.
    val shouldSortResults = true
    withSQLConf(conf.toSeq: _*) {
      try {
        val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
        val queryString = query.trim
        val outputString = output.mkString("\n").replaceAll("\\s+$", "")
        if (regenGoldenFiles) {
          val goldenOutput = {
            s"-- Automatically generated by ${getClass.getSimpleName}\n\n" +
              "-- !query schema\n" +
              schema + "\n" +
              "-- !query output\n" +
              outputString +
              "\n"
          }
          val parent = goldenFile.getParentFile
          if (!parent.exists()) {
            assert(parent.mkdirs(), "Could not create directory: " + parent)
          }
          stringToFile(goldenFile, goldenOutput)
        }

        // Read back the golden file.
        val (expectedSchema, expectedOutput) = {
          val goldenOutput = fileToString(goldenFile)
          val segments = goldenOutput.split("-- !query.*\n")

          // query has 3 segments, plus the header
          assert(
            segments.size == 3,
            s"Expected 3 blocks in result file but got ${segments.size}. " +
              "Try regenerate the result files.")

          (segments(1).trim, segments(2).replaceAll("\\s+$", ""))
        }

        val notMatchedSchemaOutput = if (schema == emptySchema) {
          // There might be an exception. See `handleExceptions`.
          s"Schema did not match\n$queryString\nOutput/Exception: $outputString"
        } else {
          s"Schema did not match\n$queryString"
        }

        assertResult(expectedSchema, notMatchedSchemaOutput) {
          schema
        }
        if (shouldSortResults) {
          val expectSorted = expectedOutput
            .split("\n")
            .sorted
            .map(_.trim)
            .mkString("\n")
            .replaceAll("\\s+$", "")
          val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "")
          assertResult(expectSorted, s"Result did not match\n$queryString") {
            outputSorted
          }
        } else {
          assertResult(expectedOutput, s"Result did not match\n$queryString") {
            outputString
          }
        }
      } catch {
        case e: Throwable =>
          val configs = conf.map { case (k, v) =>
            s"$k=$v"
          }
          throw new Exception(s"${e.getMessage}\nError using configs:\n${configs.mkString("\n")}")
      }
    }
  }

  val sortMergeJoinConf: Map[String, String] = Map(
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true")

  val broadcastHashJoinConf: Map[String, String] = Map(
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760")

  val shuffledHashJoinConf: Map[String, String] = Map(
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    "spark.sql.join.forceApplyShuffledHashJoin" -> "true")

  val allJoinConfCombinations: Seq[Map[String, String]] =
    Seq(sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf)

  val joinConfs: Seq[Map[String, String]] = if (regenGoldenFiles) {
    require(
      !sys.env.contains("SPARK_TPCDS_JOIN_CONF"),
      "'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'")
    Seq(sortMergeJoinConf)
  } else {
    sys.env
      .get("SPARK_TPCDS_JOIN_CONF")
      .map { s =>
        val p = new java.util.Properties()
        p.load(new java.io.StringReader(s))
        Seq(p.asScala.toMap)
      }
      .getOrElse(allJoinConfCombinations)
  }

  assert(joinConfs.nonEmpty)
  joinConfs.foreach(conf =>
    require(
      allJoinConfCombinations.contains(conf),
      s"Join configurations [$conf] should be one of $allJoinConfCombinations"))

  if (tpcdsDataPath.nonEmpty) {
    tpcdsQueries.foreach { name =>
      val queryString = resourceToString(
        s"tpcds/$name.sql",
        classLoader = Thread.currentThread().getContextClassLoader)
      test(name) {
        val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out")
        joinConfs.foreach { conf =>
          System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368
          runQuery(queryString, goldenFile, conf)
        }
      }
    }

    tpcdsQueriesV2_7_0.foreach { name =>
      val queryString = resourceToString(
        s"tpcds-v2.7.0/$name.sql",
        classLoader = Thread.currentThread().getContextClassLoader)
      test(s"$name-v2.7") {
        val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out")
        joinConfs.foreach { conf =>
          System.gc() // SPARK-37368
          runQuery(queryString, goldenFile, conf)
        }
      }
    }
  } else {
    ignore("skipped because env `SPARK_TPCDS_DATA` is not set") {}
  }
}
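
As a hedged illustration (not part of the commit) of the golden-file layout that runQuery reads back, the sketch below splits a hypothetical golden string with the same "-- !query" regex; the object name GoldenFileFormatExample, the schema string, and the result row are invented for the example:

object GoldenFileFormatExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical golden-file content in the format written when regenerating results.
    val golden =
      "-- Automatically generated by CometTPCDSQueryTestSuite\n\n" +
        "-- !query schema\n" +
        "struct<c_customer_id:string,total:decimal(17,2)>\n" +
        "-- !query output\n" +
        "AAAAAAAAAAAAA\t123.45\n"
    // Same split as runQuery: segment 0 is the header, 1 the schema, 2 the output rows.
    val segments = golden.split("-- !query.*\n")
    assert(segments.size == 3, s"Expected 3 blocks but got ${segments.size}")
    val expectedSchema = segments(1).trim
    val expectedOutput = segments(2).replaceAll("\\s+$", "")
    println(expectedSchema) // struct<c_customer_id:string,total:decimal(17,2)>
    println(expectedOutput) // AAAAAAAAAAAAA	123.45
  }
}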
