
Commit 50eb0bf

Refactored tests to use streaming test tools
- Made mllib depend on the test artifacts from streaming
- Rewrote all streamingLR tests to use the setupStreams & runStreams functions
1 parent 32c43c2 commit 50eb0bf
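
For readers new to the streaming test harness, the pattern adopted throughout this commit works roughly as follows: setupStreams takes a Seq of per-batch input Seqs plus a DStream operation and wires them into a StreamingContext, and runStreams executes a fixed number of batches and collects each batch's output. A minimal sketch (illustrative only, not part of the commit):

    // Each inner Seq becomes one simulated batch.
    val input: Seq[Seq[Int]] = Seq(Seq(1, 2), Seq(3, 4))
    val ssc = setupStreams(input, (stream: DStream[Int]) => stream.map(_ * 2))
    // Run two batches and wait for two batches of output.
    val output: Seq[Seq[Int]] = runStreams(ssc, numBatches = 2, numExpectedOutput = 2)
    // output == Seq(Seq(2, 4), Seq(6, 8))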

3 files changed: +63 −103 lines

mllib/pom.xml

Lines changed: 7 additions & 0 deletions
@@ -91,6 +91,13 @@
       <artifactId>junit-interface</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <profiles>
     <profile>
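
As an aside, a Maven test-jar dependency like this corresponds to the "tests" classifier in sbt; assuming a sparkVersion setting, the equivalent would look roughly like the following (the commit itself only touches the Maven build):

    libraryDependencies +=
      "org.apache.spark" %% "spark-streaming" % sparkVersion % "test" classifier "tests"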

mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala

Lines changed: 55 additions & 102 deletions
@@ -17,20 +17,19 @@
 
 package org.apache.spark.mllib.regression
 
-import java.io.File
-import java.nio.charset.Charset
-
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.Files
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase
+
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+  // use longer wait time to ensure job completion
+  override def maxWaitTimeMillis = 20000
 
   // Assert that two values are equal within tolerance epsilon
   def assertEqual(v1: Double, v2: Double, epsilon: Double) {
@@ -51,32 +50,24 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
   test("parameter accuracy") {
 
-    val testDir = Files.createTempDir()
-    val numBatches = 10
-    val batchDuration = Milliseconds(1000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0, 0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
+      .setNumIterations(25)
 
-    ssc.start()
-
-    // write data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
-
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
     // check accuracy of final parameter estimates
     assertEqual(model.latestModel().intercept, 0.0, 0.1)
@@ -92,36 +83,31 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
   test("parameter convergence") {
 
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(2000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val numBatches = 5
-    val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
+      .setNumIterations(25)
 
-    ssc.start()
-
-    // write data to a file stream
-    val history = new ArrayBuffer[Double](numBatches)
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-      // wait an extra few seconds to make sure the update finishes before new data arrive
-      Thread.sleep(4000)
-      history.append(math.abs(model.latestModel().weights(0) - 10.0))
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
+    // create buffer to store intermediate fits
+    val history = new ArrayBuffer[Double](numBatches)
 
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream, storing the intermediate results
+    // (we add a count to ensure the result is a DStream)
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
+    // compute change in error
     val deltas = history.drop(1).zip(history.dropRight(1))
     // check error stability (it always either shrinks, or increases with small tol)
     assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
@@ -133,63 +119,30 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   // Test predictions on a stream
   test("predictions") {
 
-    val trainDir = Files.createTempDir()
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(1000)
-    val numBatches = 10
-    val nPoints = 100
-
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(trainDir.toString).map(LabeledPoint.parse)
+    // create model initialized with true weights
     val model = new StreamingLinearRegressionWithSGD()
-      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setInitialWeights(Vectors.dense(10.0, 10.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
+      .setNumIterations(25)
 
-    model.trainOn(data)
-
-    ssc.start()
-
-    // write training data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
-      val file = new File(trainDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-    }
-
-    ssc.stop(stopSparkContext=false)
-
-    Utils.deleteRecursively(trainDir)
-
-    print(model.latestModel().weights.toArray.mkString(" "))
-    print(model.latestModel().intercept)
-
-    val ssc2 = new StreamingContext(sc, batchDuration)
-    val data2 = ssc2.textFileStream(testDir.toString).map(LabeledPoint.parse)
-
-    val history = new ArrayBuffer[Double](numBatches)
-    val predictions = model.predictOnValues(data2.map(x => (x.label, x.features)))
-    val errors = predictions.map(x => math.abs(x._1 - x._2))
-    errors.foreachRDD(rdd => history.append(rdd.reduce(_+_) / nPoints.toDouble))
-
-    ssc2.start()
-
-    // write test data to a file stream
-
-    // make a function
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data for testing
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
     }
 
-    println(history)
-
-    ssc2.stop(stopSparkContext=false)
+    // apply model predictions to test stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+    })
+    // collect the output as (true, estimated) tuples
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // compute the mean absolute error and check that it's always less than 0.1
+    val errors = output.map(batch => batch.map(
+      p => math.abs(p._1 - p._2)).reduce(_+_) / nPoints.toDouble)
+    assert(errors.forall(x => x <= 0.1))
 
   }
 
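As a sanity note on the runStreams semantics used above: the returned output holds one inner Seq per batch, so with the parameters in the "predictions" test one would expect ten batches of one hundred (label, prediction) pairs each. A hypothetical shape check (not in the commit) would be:

    assert(output.size == numBatches)        // one inner Seq per batch
    assert(output.forall(_.size == nPoints)) // one prediction per input point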

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
     logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
 
     // Get the output buffer
-    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+    val outputStream = ssc.graph.getOutputStreams.filter(_.isInstanceOf[TestOutputStreamWithPartitions[_]]).head.asInstanceOf[TestOutputStreamWithPartitions[V]]
     val output = outputStream.output
 
     try {
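
This one-line change is what allows the mllib tests above to work: trainOn consumes the stream through its own foreachRDD, so the graph ends up with more than one output stream and the TestOutputStream created by setupStreams is no longer guaranteed to be first. Schematically (stream names illustrative):

    // After a test body like { model.trainOn(stream); stream.count() }, the graph
    // holds a ForEachDStream (from trainOn) plus the TestOutputStreamWithPartitions
    // wrapping count(). Taking .head blindly could select the former; filtering by
    // type before casting always recovers the test output buffer.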
