Commit a80f9b0

Merge pull request #4 from marmbrus/pr/6885

Add simple resolver

2 parents: c60a44d + d9ab1e4

File tree: 2 files changed (+52, -11 lines)


sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala

Lines changed: 11 additions & 9 deletions
@@ -18,9 +18,13 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Ascending, SortOrder}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType}
 
 class SortSuite extends SparkPlanTest {
+  import TestSQLContext.implicits.localSeqToDataFrameHolder
 
   test("basic sorting using ExternalSort") {
@@ -30,16 +34,14 @@ class SortSuite extends SparkPlanTest {
       ("World", 8)
     )
 
-    val sortOrder = Seq(
-      SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
-      SortOrder(BoundReference(1, IntegerType, nullable = false), Ascending)
-    )
-
     checkAnswer(
-      input,
-      (child: SparkPlan) => new ExternalSort(sortOrder, global = false, child),
-      input.sorted
-    )
+      input.toDF("a", "b"),
+      ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
+      input.sorted)
 
+    checkAnswer(
+      input.toDF("a", "b"),
+      ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
+      input.sortBy(t => (t._2, t._1)))
   }
 }
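
A note on the DSL in the rewritten test: in Catalyst's expression DSL, a Scala Symbol such as 'a converts implicitly to an UnresolvedAttribute, and .asc wraps it in an ascending SortOrder, so the hand-built BoundReference sort keys are no longer needed. The trailing "_: SparkPlan" is ordinary placeholder syntax that turns the ExternalSort constructor call into the SparkPlan => SparkPlan function that checkAnswer expects. A rough sketch of the equivalence (illustrative only, not part of the commit):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.execution.{ExternalSort, SparkPlan}

// 'a.asc builds SortOrder(UnresolvedAttribute("a"), Ascending); the name stays
// unresolved until SparkPlanTest's simple resolver (below) binds it to an
// ordinal in the child plan's output.
val byA: SortOrder = 'a.asc

// The placeholder form used in the test is shorthand for this explicit lambda:
val planFunction: SparkPlan => SparkPlan =
  child => ExternalSort(byA :: Nil, global = false, child)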

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala

Lines changed: 41 additions & 2 deletions
@@ -21,9 +21,13 @@ import scala.util.control.NonFatal
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.SparkFunSuite
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.BoundReference
+import org.apache.spark.sql.catalyst.util._
+
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.{Row, DataFrame}
-import org.apache.spark.sql.catalyst.util._
 
 /**
  * Base class for writing tests for individual physical operators. For an example of how this
@@ -48,6 +52,24 @@ class SparkPlanTest extends SparkFunSuite
     }
   }
 
+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   * @param input the input data to be used.
+   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+   *                     the physical operator that's being tested.
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+   */
+  protected def checkAnswer[A <: Product : TypeTag](
+      input: DataFrame,
+      planFunction: SparkPlan => SparkPlan,
+      expectedAnswer: Seq[A]): Unit = {
+    val expectedRows = expectedAnswer.map(Row.fromTuple)
+    SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match {
+      case Some(errorMessage) => fail(errorMessage)
+      case None =>
+    }
+  }
+
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    * @param input the input data to be used.
@@ -87,6 +109,23 @@ object SparkPlanTest {
 
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
 
+    // A very simple resolver to make writing tests easier. In contrast to the real resolver
+    // this is always case sensitive and does not try to handle scoping or complex type resolution.
+    val resolvedPlan = outputPlan transform {
+      case plan: SparkPlan =>
+        val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
+          case (a, i) =>
+            (a.name, BoundReference(i, a.dataType, a.nullable))
+        }.toMap
+
+        plan.transformExpressions {
+          case UnresolvedAttribute(Seq(u)) =>
+            inputMap.get(u).getOrElse {
+              sys.error(s"Invalid Test: Cannot resolve $u given input ${inputMap}")
+            }
+        }
+    }
+
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.
       // For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -105,7 +144,7 @@ object SparkPlanTest {
     }
 
     val sparkAnswer: Seq[Row] = try {
-      outputPlan.executeCollect().toSeq
+      resolvedPlan.executeCollect().toSeq
     } catch {
       case NonFatal(e) =>
         val errorMessage =
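
To make the new resolver concrete: for each physical operator it maps every child-output column name to a BoundReference carrying that column's ordinal, data type, and nullability, then substitutes the reference for any UnresolvedAttribute with a matching name. A small sketch under assumed inputs, namely a child whose output is a: string at ordinal 0 and b: int at ordinal 1 (the names are illustrative, not from the commit):

import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.{IntegerType, StringType}

// The map the resolver would derive from such a child's output:
val inputMap = Map(
  "a" -> BoundReference(0, StringType, nullable = true),
  "b" -> BoundReference(1, IntegerType, nullable = true))

// UnresolvedAttribute(Seq("b")) is then rewritten to inputMap("b"), i.e.
// BoundReference(1, IntegerType, nullable = true); an unknown name falls
// through to sys.error, failing the test immediately.

Because the lookup is a plain Map keyed on the raw attribute name, resolution is case sensitive and ignores qualifiers and complex types, exactly the simplification the inline comment calls out.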
