Skip to content

Commit 32c43c2

Browse files
committed
Added test for prediction
- Test predictOnValues for accuracy on a test stream
1 parent 217b5e9 commit 32c43c2

File tree

1 file changed

+65
-4
lines changed

1 file changed

+65
-4
lines changed

mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
4949
}
5050

5151
// Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
52-
test("streaming linear regression parameter accuracy") {
52+
test("parameter accuracy") {
5353

5454
val testDir = Files.createTempDir()
5555
val numBatches = 10
@@ -76,7 +76,6 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
7676

7777
ssc.stop(stopSparkContext=false)
7878

79-
System.clearProperty("spark.driver.port")
8079
Utils.deleteRecursively(testDir)
8180

8281
// check accuracy of final parameter estimates
@@ -91,7 +90,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
9190
}
9291

9392
// Test that parameter estimates improve when learning Y = 10*X1 on streaming data
94-
test("streaming linear regression parameter convergence") {
93+
test("parameter convergence") {
9594

9695
val testDir = Files.createTempDir()
9796
val batchDuration = Milliseconds(2000)
@@ -121,7 +120,6 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
121120

122121
ssc.stop(stopSparkContext=false)
123122

124-
System.clearProperty("spark.driver.port")
125123
Utils.deleteRecursively(testDir)
126124

127125
val deltas = history.drop(1).zip(history.dropRight(1))
@@ -132,4 +130,67 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
132130

133131
}
134132

133+
// Test predictions on a stream
134+
test("predictions") {
135+
136+
val trainDir = Files.createTempDir()
137+
val testDir = Files.createTempDir()
138+
val batchDuration = Milliseconds(1000)
139+
val numBatches = 10
140+
val nPoints = 100
141+
142+
val ssc = new StreamingContext(sc, batchDuration)
143+
val data = ssc.textFileStream(trainDir.toString).map(LabeledPoint.parse)
144+
val model = new StreamingLinearRegressionWithSGD()
145+
.setInitialWeights(Vectors.dense(0.0, 0.0))
146+
.setStepSize(0.1)
147+
.setNumIterations(50)
148+
149+
model.trainOn(data)
150+
151+
ssc.start()
152+
153+
// write training data to a file stream
154+
for (i <- 0 until numBatches) {
155+
val samples = LinearDataGenerator.generateLinearInput(
156+
0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
157+
val file = new File(trainDir, i.toString)
158+
Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
159+
Thread.sleep(batchDuration.milliseconds)
160+
}
161+
162+
ssc.stop(stopSparkContext=false)
163+
164+
Utils.deleteRecursively(trainDir)
165+
166+
print(model.latestModel().weights.toArray.mkString(" "))
167+
print(model.latestModel().intercept)
168+
169+
val ssc2 = new StreamingContext(sc, batchDuration)
170+
val data2 = ssc2.textFileStream(testDir.toString).map(LabeledPoint.parse)
171+
172+
val history = new ArrayBuffer[Double](numBatches)
173+
val predictions = model.predictOnValues(data2.map(x => (x.label, x.features)))
174+
val errors = predictions.map(x => math.abs(x._1 - x._2))
175+
errors.foreachRDD(rdd => history.append(rdd.reduce(_+_) / nPoints.toDouble))
176+
177+
ssc2.start()
178+
179+
// write test data to a file stream
180+
181+
// make a function
182+
for (i <- 0 until numBatches) {
183+
val samples = LinearDataGenerator.generateLinearInput(
184+
0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
185+
val file = new File(testDir, i.toString)
186+
Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
187+
Thread.sleep(batchDuration.milliseconds)
188+
}
189+
190+
println(history)
191+
192+
ssc2.stop(stopSparkContext=false)
193+
194+
}
195+
135196
}

0 commit comments

Comments (0)