
Commit 35a19f3

Authored and committed by Andrew Or
[SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext
Instead of relying on `DataFrames` to verify our answers, we can just use simple arrays. This significantly simplifies the test logic for `LocalNode`s and reduces a lot of code duplicated from `SparkPlanTest`.

This also fixes an additional issue, [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624), where the output of `TakeOrderedAndProjectNode` is not actually ordered.

Author: Andrew Or <andrew@databricks.com>

Closes #8764 from andrewor14/sql-local-tests-cleanup.
1 parent 38700ea commit 35a19f3

17 files changed with 468 additions and 636 deletions
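Roughly, the updated suites feed a plain array into a `DummyNode` and assert directly on the collected rows, instead of building `DataFrames` and calling `checkAnswer`. A minimal sketch of that pattern (`kvIntAttributes`, `resolveExpressions`, and `conf` are assumed to be helpers from `LocalNodeTest`, which is not shown in this excerpt; the full version appears in the FilterNodeSuite diff below):

    import org.apache.spark.sql.catalyst.dsl.expressions._

    // Sketch only -- plain array in, plain array out, no SQLContext or DataFrame involved.
    val inputData = (1 to 100).map { i => (i, i) }.toArray
    val inputNode = new DummyNode(kvIntAttributes, inputData)
    val filterNode = resolveExpressions(new FilterNode(conf, 'k % 2 === 0, inputNode))
    val actualOutput = filterNode.collect().map { row => (row.getInt(0), row.getInt(1)) }
    assert(actualOutput === inputData.filter { case (k, _) => k % 2 == 0 })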

sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala

Lines changed: 2 additions & 6 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType
  * Before consuming the iterator, open function must be called.
  * After consuming the iterator, close function must be called.
  */
-abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
+abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging {
 
   protected val codegenEnabled: Boolean = conf.codegenEnabled
 
   protected val unsafeEnabled: Boolean = conf.unsafeEnabled
 
-  lazy val schema: StructType = StructType.fromAttributes(output)
-
   private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")
 
-  def output: Seq[Attribute]
-
   /**
    * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume
    * any input data.

sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala

Lines changed: 6 additions & 10 deletions
@@ -17,13 +17,12 @@
 
 package org.apache.spark.sql.execution.local
 
-import java.util.Random
-
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
+
 /**
  * Sample the dataset.
  *
@@ -51,18 +50,15 @@ case class SampleNode(
 
   override def open(): Unit = {
     child.open()
-    val (sampler, _seed) = if (withReplacement) {
-      val random = new Random(seed)
+    val sampler =
+      if (withReplacement) {
         // Disable gap sampling since the gap sampling method buffers two rows internally,
         // requiring us to copy the row, which is more expensive than the random number generator.
-        (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
-          // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
-          // of DataFrame
-          random.nextLong())
+        new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
       } else {
-        (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
+        new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
       }
-    sampler.setSeed(_seed)
+    sampler.setSeed(seed)
     iterator = sampler.sample(child.asIterator)
   }
 
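With the `DataFrame` comparison gone, deriving a partition-0 seed through `java.util.Random` is no longer needed; the node only has to be deterministic for a given seed. A rough, self-contained illustration of that property using the same sampler class (the bounds, seed, and input below are made up):

    import org.apache.spark.util.random.BernoulliCellSampler

    // Identically seeded samplers select the same elements from the same input.
    def sampled(seed: Long): Seq[Int] = {
      val sampler = new BernoulliCellSampler[Int](0.0, 0.5)  // keep roughly half the input
      sampler.setSeed(seed)
      sampler.sample((1 to 20).iterator).toSeq
    }
    assert(sampled(42L) == sampled(42L))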

sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode(
     }
     // Close it eagerly since we don't need it.
     child.close()
-    iterator = queue.iterator
+    iterator = queue.toArray.sorted(ord).iterator
   }
 
   override def next(): Boolean = {
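This is the [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624) fix: the bounded priority queue holds the right rows, but iterating over it directly yields them in internal heap order rather than in `ord` order, so the contents are sorted explicitly before being handed out. A small standard-library illustration of the pitfall (the node itself uses Spark's `BoundedPriorityQueue`):

    import scala.collection.mutable

    val queue = mutable.PriorityQueue(3, 1, 4, 1, 5, 9, 2, 6)
    println(queue.iterator.toList)        // heap-dependent order, not necessarily sorted
    println(queue.toArray.sorted.toList)  // List(1, 1, 2, 3, 4, 5, 6, 9)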

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ object SparkPlanTest {
     outputPlan transform {
       case plan: SparkPlan =>
        val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
-        plan.transformExpressions {
+        plan transformExpressions {
          case UnresolvedAttribute(Seq(u)) =>
            inputMap.getOrElse(u,
              sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+/**
+ * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
+ */
+private[local] case class DummyNode(
+    output: Seq[Attribute],
+    relation: LocalRelation,
+    conf: SQLConf)
+  extends LocalNode(conf) {
+
+  import DummyNode._
+
+  private var index: Int = CLOSED
+  private val input: Seq[InternalRow] = relation.data
+
+  def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) {
+    this(output, LocalRelation.fromProduct(output, data), conf)
+  }
+
+  def isOpen: Boolean = index != CLOSED
+
+  override def children: Seq[LocalNode] = Seq.empty
+
+  override def open(): Unit = {
+    index = -1
+  }
+
+  override def next(): Boolean = {
+    index += 1
+    index < input.size
+  }
+
+  override def fetch(): InternalRow = {
+    assert(index >= 0 && index < input.size)
+    input(index)
+  }
+
+  override def close(): Unit = {
+    index = CLOSED
+  }
+}
+
+private object DummyNode {
+  val CLOSED: Int = Int.MinValue
+}
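For reference, a hypothetical snippet driving `DummyNode` through the `LocalNode` open/next/fetch/close protocol by hand (`kvIntAttributes` is assumed to be the two-int-column attribute helper from `LocalNodeTest`):

    val node = new DummyNode(kvIntAttributes, (1 to 3).map { i => (i, i * 10) })
    node.open()              // resets the internal index; isOpen is now true
    while (node.next()) {    // advances the index; returns false past the last row
      val row = node.fetch()
      println((row.getInt(0), row.getInt(1)))
    }
    node.close()             // back to CLOSED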

sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala

Lines changed: 26 additions & 28 deletions
@@ -17,35 +17,33 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+
 class ExpandNodeSuite extends LocalNodeTest {
 
-  import testImplicits._
-
-  test("expand") {
-    val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
-    checkAnswer(
-      input,
-      node =>
-        ExpandNode(conf, Seq(
-          Seq(
-            input.col("key") + input.col("value"), input.col("key") - input.col("value")
-          ).map(_.expr),
-          Seq(
-            input.col("key") * input.col("value"), input.col("key") / input.col("value")
-          ).map(_.expr)
-        ), node.output, node),
-      Seq(
-        (2, 0),
-        (1, 1),
-        (4, 0),
-        (4, 1),
-        (6, 0),
-        (9, 1),
-        (8, 0),
-        (16, 1),
-        (10, 0),
-        (25, 1)
-      ).toDF().collect()
-    )
+  private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = {
+    val inputNode = new DummyNode(kvIntAttributes, inputData)
+    val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v))
+    val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode)
+    val resolvedNode = resolveExpressions(expandNode)
+    val expectedOutput = {
+      val firstHalf = inputData.map { case (k, v) => (k + v, k - v) }
+      val secondHalf = inputData.map { case (k, v) => (k * v, k / v) }
+      firstHalf ++ secondHalf
+    }
+    val actualOutput = resolvedNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput.toSet === expectedOutput.toSet)
+  }
+
+  test("empty") {
+    testExpand()
   }
+
+  test("basic") {
+    testExpand((1 to 100).map { i => (i, i * 1000) }.toArray)
+  }
+
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala

Lines changed: 19 additions & 15 deletions
@@ -17,25 +17,29 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.catalyst.dsl.expressions._
 
-class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
 
-  test("basic") {
-    val condition = (testData.col("key") % 2) === 0
-    checkAnswer(
-      testData,
-      node => FilterNode(conf, condition.expr, node),
-      testData.filter(condition).collect()
-    )
+class FilterNodeSuite extends LocalNodeTest {
+
+  private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
+    val cond = 'k % 2 === 0
+    val inputNode = new DummyNode(kvIntAttributes, inputData)
+    val filterNode = new FilterNode(conf, cond, inputNode)
+    val resolvedNode = resolveExpressions(filterNode)
+    val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 }
+    val actualOutput = resolvedNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput === expectedOutput)
   }
 
   test("empty") {
-    val condition = (emptyTestData.col("key") % 2) === 0
-    checkAnswer(
-      emptyTestData,
-      node => FilterNode(conf, condition.expr, node),
-      emptyTestData.filter(condition).collect()
-    )
+    testFilter()
+  }
+
+  test("basic") {
+    testFilter((1 to 100).map { i => (i, i) }.toArray)
   }
+
 }
