@@ -21,9 +21,13 @@ import scala.util.control.NonFatal
21
21
import scala .reflect .runtime .universe .TypeTag
22
22
23
23
import org .apache .spark .SparkFunSuite
24
+
25
+ import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
26
+ import org .apache .spark .sql .catalyst .expressions .BoundReference
27
+ import org .apache .spark .sql .catalyst .util ._
28
+
24
29
import org .apache .spark .sql .test .TestSQLContext
25
30
import org .apache .spark .sql .{Row , DataFrame }
26
- import org .apache .spark .sql .catalyst .util ._
27
31
28
32
/**
29
33
* Base class for writing tests for individual physical operators. For an example of how this
@@ -48,6 +52,24 @@ class SparkPlanTest extends SparkFunSuite {
48
52
}
49
53
}
50
54
55
+ /**
56
+ * Runs the plan and makes sure the answer matches the expected result.
57
+ * @param input the input data to be used.
58
+ * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
59
+ * the physical operator that's being tested.
60
+ * @param expectedAnswer the expected result in a [[Seq ]] of [[Product ]]s.
61
+ */
62
+ protected def checkAnswer [A <: Product : TypeTag ](
63
+ input : DataFrame ,
64
+ planFunction : SparkPlan => SparkPlan ,
65
+ expectedAnswer : Seq [A ]): Unit = {
66
+ val expectedRows = expectedAnswer.map(Row .fromTuple)
67
+ SparkPlanTest .checkAnswer(input, planFunction, expectedRows) match {
68
+ case Some (errorMessage) => fail(errorMessage)
69
+ case None =>
70
+ }
71
+ }
72
+
51
73
/**
52
74
* Runs the plan and makes sure the answer matches the expected result.
53
75
* @param input the input data to be used.
@@ -87,6 +109,23 @@ object SparkPlanTest {
87
109
88
110
val outputPlan = planFunction(input.queryExecution.sparkPlan)
89
111
112
+ // A very simple resolver to make writing tests easier. In contrast to the real resolver
113
+ // this is always case sensitive and does not try to handle scoping or complex type resolution.
114
+ val resolvedPlan = outputPlan transform {
115
+ case plan : SparkPlan =>
116
+ val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
117
+ case (a, i) =>
118
+ (a.name, BoundReference (i, a.dataType, a.nullable))
119
+ }.toMap
120
+
121
+ plan.transformExpressions {
122
+ case UnresolvedAttribute (Seq (u)) =>
123
+ inputMap.get(u).getOrElse {
124
+ sys.error(s " Invalid Test: Cannot resolve $u given input ${inputMap}" )
125
+ }
126
+ }
127
+ }
128
+
90
129
def prepareAnswer (answer : Seq [Row ]): Seq [Row ] = {
91
130
// Converts data to types that we can do equality comparison using Scala collections.
92
131
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -105,7 +144,7 @@ object SparkPlanTest {
105
144
}
106
145
107
146
val sparkAnswer : Seq [Row ] = try {
108
- outputPlan .executeCollect().toSeq
147
+ resolvedPlan .executeCollect().toSeq
109
148
} catch {
110
149
case NonFatal (e) =>
111
150
val errorMessage =
0 commit comments