Commit a93a260

Author: Andrew Or (committed)
TakeOrderedAndProject + Sample
1 parent 10fc109 commit a93a260

File tree: 4 files changed (+82, -73 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala

Lines changed: 7 additions & 11 deletions

@@ -17,12 +17,11 @@
 
 package org.apache.spark.sql.execution.local
 
-import java.util.Random
-
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler, RandomSampler}
+
 
 /**
  * Sample the dataset.
@@ -51,18 +50,15 @@ case class SampleNode(
 
   override def open(): Unit = {
     child.open()
-    val (sampler, _seed) = if (withReplacement) {
-      val random = new Random(seed)
+    val sampler =
+      if (withReplacement) {
        // Disable gap sampling since the gap sampling method buffers two rows internally,
        // requiring us to copy the row, which is more expensive than the random number generator.
-      (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
-        // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
-        // of DataFrame
-        random.nextLong())
+        new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
      } else {
-      (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
+        new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
      }
-    sampler.setSeed(_seed)
+    sampler.setSeed(seed)
     iterator = sampler.sample(child.asIterator)
   }
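The change above drops the intermediate java.util.Random and seeds the sampler with the node's own seed, which is what makes the output reproducible for a given seed. Below is a minimal, self-contained sketch of that sampler selection, using the same org.apache.spark.util.random classes; the object name, the Int element type, and the bounds are invented for illustration and are not part of the commit.

import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler, RandomSampler}

object SamplerSketch {
  def main(args: Array[String]): Unit = {
    val seed = 0L
    val withReplacement = false
    val (lowerBound, upperBound) = (0.0, 0.3)
    val sampler: RandomSampler[Int, Int] =
      if (withReplacement) {
        // Poisson sampling approximates sampling with replacement at the given fraction.
        new PoissonSampler[Int](upperBound - lowerBound, useGapSamplingIfPossible = false)
      } else {
        // Bernoulli cell sampling keeps elements whose random draw falls in [lowerBound, upperBound).
        new BernoulliCellSampler[Int](lowerBound, upperBound)
      }
    // Same seed in, same sample out: the seed is handed straight to the sampler.
    sampler.setSeed(seed)
    val sampled = sampler.sample((1 to 1000).iterator).toArray
    println(s"kept ${sampled.length} of 1000 elements")
  }
}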

sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala

Lines changed: 23 additions & 12 deletions

@@ -17,21 +17,32 @@
 
 package org.apache.spark.sql.execution.local
 
-class SampleNodeSuite extends LocalNodeTest {
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
+
 
-  import testImplicits._
+class SampleNodeSuite extends LocalNodeTest {
 
   private def testSample(withReplacement: Boolean): Unit = {
-    test(s"withReplacement: $withReplacement") {
-      val seed = 0L
-      val input = sqlContext.sparkContext.
-        parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
-        toDF("key", "value")
-      checkAnswer(
-        input,
-        node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
-        input.sample(withReplacement, 0.3, seed).collect()
-      )
+    val seed = 0L
+    val lowerb = 0.0
+    val upperb = 0.3
+    val maybeOut = if (withReplacement) "" else "out"
+    test(s"with$maybeOut replacement") {
+      val inputData = (1 to 1000).map { i => (i, i) }.toArray
+      val inputNode = new DummyNode(kvIntAttributes, inputData)
+      val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode)
+      val sampler =
+        if (withReplacement) {
+          new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false)
+        } else {
+          new BernoulliCellSampler[(Int, Int)](lowerb, upperb)
+        }
+      sampler.setSeed(seed)
+      val expectedOutput = sampler.sample(inputData.iterator).toArray
+      val actualOutput = sampleNode.collect().map { case row =>
+        (row.getInt(0), row.getInt(1))
+      }
+      assert(actualOutput === expectedOutput)
     }
   }
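The rewritten suite no longer round-trips through DataFrame.sample; it feeds the node a DummyNode and builds the expected answer by running an identically configured sampler over the raw input. A small sketch of the determinism that comparison relies on, assuming only that a sampler with a fixed seed is repeatable (this snippet is an illustration, not part of the commit):

import org.apache.spark.util.random.BernoulliCellSampler

// Sampling the same data twice with the same seed must yield the same rows,
// otherwise expectedOutput (built from the sampler) could not match the node's output.
val data = (1 to 1000).map { i => (i, i) }
def sampleOnce(seed: Long): Seq[(Int, Int)] = {
  val sampler = new BernoulliCellSampler[(Int, Int)](0.0, 0.3)
  sampler.setSeed(seed)
  sampler.sample(data.iterator).toSeq
}
assert(sampleOnce(0L) == sampleOnce(0L))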

sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala

Lines changed: 26 additions & 27 deletions

@@ -17,38 +17,37 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
+import scala.util.Random
 
-class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SortOrder
 
-  import testImplicits._
 
-  private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
-    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
-      col.expr match {
-        case expr: SortOrder =>
-          expr
-        case expr: Expression =>
-          SortOrder(expr, Ascending)
-      }
-    }
-    sortOrder
-  }
+class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
 
-  private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
-    val testCaseName = if (desc) "desc" else "asc"
-    test(testCaseName) {
-      val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
-      val sortColumn = if (desc) input.col("key").desc else input.col("key")
-      checkAnswer(
-        input,
-        node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
-        input.sort(sortColumn).limit(5).collect()
-      )
+  private def testTakeOrderedAndProject(desc: Boolean): Unit = {
+    val limit = 10
+    val ascOrDesc = if (desc) "desc" else "asc"
+    // TODO: re-enable me once TakeOrderedAndProjectNode can return things in sorted order.
+    // This test is ignored because the node currently just returns the items in the order
+    // maintained by the underlying min / max heap, but we expect sorted order.
+    ignore(ascOrDesc) {
+      val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray
+      val inputNode = new DummyNode(kvIntAttributes, inputData)
+      val firstColumn = inputNode.output(0)
+      val sortDirection = if (desc) Descending else Ascending
+      val sortOrder = SortOrder(firstColumn, sortDirection)
+      val takeOrderAndProjectNode = new TakeOrderedAndProjectNode(
+        conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode)
+      val expectedOutput = inputData
+        .map { case (k, _) => k }
+        .sortBy { k => k * (if (desc) -1 else 1) }
+        .take(limit)
+      val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) }
+      assert(actualOutput === expectedOutput)
     }
   }
 
-  testTakeOrderedAndProjectNode(desc = false)
-  testTakeOrderedAndProjectNode(desc = true)
+  testTakeOrderedAndProject(desc = false)
+  testTakeOrderedAndProject(desc = true)
 }
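The ignored test spells out the current limitation: the node keeps the top `limit` rows in a bounded min/max heap, and iterating a heap does not produce sorted output. A plain-Scala sketch of that behavior (my own illustration of the technique, not the node's actual code):

import scala.collection.mutable
import scala.util.Random

// Keep the 10 smallest keys using a bounded max-heap, the usual "take ordered" trick.
val limit = 10
val input = Random.shuffle((1 to 100).toList)
val heap = mutable.PriorityQueue.empty[Int] // max-heap: the head is the largest kept key
input.foreach { k =>
  if (heap.size < limit) {
    heap.enqueue(k)
  } else if (k < heap.head) {
    heap.dequeue()
    heap.enqueue(k)
  }
}
// The heap holds the right elements, but its iteration order is the heap layout,
// not ascending order, hence the need for a final sort before returning rows.
assert(heap.toList.sorted == (1 to limit).toList)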

sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala

Lines changed: 26 additions & 23 deletions

@@ -17,36 +17,39 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.test.SharedSQLContext
 
-class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
+class UnionNodeSuite extends LocalNodeTest {
 
-  test("basic") {
-    checkAnswer2(
-      testData,
-      testData,
-      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
-      testData.unionAll(testData).collect()
-    )
+  private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = {
+    val inputNodes = inputData.map { data =>
+      new DummyNode(kvIntAttributes, data)
+    }
+    val unionNode = new UnionNode(conf, inputNodes)
+    val expectedOutput = inputData.flatten
+    val actualOutput = unionNode.collect().map { case row =>
+      (row.getInt(0), row.getInt(1))
+    }
+    assert(actualOutput === expectedOutput)
   }
 
   test("empty") {
-    checkAnswer2(
-      emptyTestData,
-      emptyTestData,
-      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
-      emptyTestData.unionAll(emptyTestData).collect()
-    )
+    testUnion(Seq(Array.empty))
+    testUnion(Seq(Array.empty, Array.empty))
+  }
+
+  test("self") {
+    val data = (1 to 100).map { i => (i, i) }.toArray
+    testUnion(Seq(data))
+    testUnion(Seq(data, data))
+    testUnion(Seq(data, data, data))
   }
 
-  test("complicated union") {
-    val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData,
-      emptyTestData, emptyTestData, testData, emptyTestData)
-    doCheckAnswer(
-      dfs,
-      nodes => UnionNode(conf, nodes),
-      dfs.reduce(_.unionAll(_)).collect()
-    )
+  test("basic") {
+    val zero = Array.empty[(Int, Int)]
+    val one = (1 to 100).map { i => (i, i) }.toArray
+    val two = (50 to 150).map { i => (i, i) }.toArray
+    val three = (800 to 900).map { i => (i, i) }.toArray
+    testUnion(Seq(zero, one, two, three))
  }
 
 }
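The new testUnion helper pins down the contract being tested: a local union just concatenates its children's rows in child order, so the expected output is the flattened input. A tiny sketch of that contract (illustrative only, not from the commit):

// Union of local nodes is plain concatenation of the children's rows, in child order.
val inputs: Seq[Array[(Int, Int)]] =
  Seq(Array((1, 1), (2, 2)), Array.empty[(Int, Int)], Array((3, 3)))
val unioned = inputs.iterator.flatMap(_.iterator).toList
assert(unioned == inputs.flatten.toList)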
