package org.apache.spark.mllib.regression

-import java.io.File
-import java.nio.charset.Charset
-
import scala.collection.mutable.ArrayBuffer

-import com.google.common.io.Files
import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase
+
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {

-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+  // use longer wait time to ensure job completion
+  override def maxWaitTimeMillis = 20000

  // Assert that two values are equal within tolerance epsilon
  def assertEqual(v1: Double, v2: Double, epsilon: Double) {
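For context on the pattern this change adopts: TestSuiteBase (Spark's internal streaming test harness) replaces the temp-directory file streams with in-memory batches, so the tests no longer need file I/O, explicit StreamingContext plumbing, or manual cleanup. Below is a minimal sketch of how the harness is used, assuming the setupStreams/runStreams signatures implied by the diff; ExampleSuite and the doubling operation are hypothetical.

import org.scalatest.FunSuite
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream

// hypothetical suite illustrating the TestSuiteBase pattern
class ExampleSuite extends FunSuite with TestSuiteBase {
  test("each input batch is processed as one DStream batch") {
    // one Seq per batch: two batches of two ints
    val input = Seq(Seq(1, 2), Seq(3, 4))
    // wire an in-memory input stream to the operation under test
    val ssc = setupStreams(input, (ds: DStream[Int]) => ds.map(_ * 2))
    // run the batches and collect one Seq of output per batch
    val output: Seq[Seq[Int]] = runStreams[Int](ssc, 2, 2)
    assert(output.flatten === Seq(2, 4, 6, 8))
  }
}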
@@ -51,32 +50,24 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
  // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
  test("parameter accuracy") {

-    val testDir = Files.createTempDir()
-    val numBatches = 10
-    val batchDuration = Milliseconds(1000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+    // create model
    val model = new StreamingLinearRegressionWithSGD()
      .setInitialWeights(Vectors.dense(0.0, 0.0))
      .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
+      .setNumIterations(25)

-    ssc.start()
-
-    // write data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
    }

-    ssc.stop(stopSparkContext = false)
-
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)

    // check accuracy of final parameter estimates
    assertEqual(model.latestModel().intercept, 0.0, 0.1)
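Each batch above is an independent draw from the same linear model. As a rough sketch of what LinearDataGenerator.generateLinearInput(intercept, weights, nPoints, seed) produces (a Seq of nPoints LabeledPoints whose labels are roughly the given linear combination of the features plus small noise):

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.LinearDataGenerator

// one batch of 100 points from Y = 10*X1 + 10*X2 with intercept 0.0 and seed 42
val batch: Seq[LabeledPoint] =
  LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42)
// for each point, label is approximately 10.0 * x1 + 10.0 * x2 for features (x1, x2)
assert(batch.size == 100)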
@@ -92,36 +83,31 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
  // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
  test("parameter convergence") {

-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(2000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val numBatches = 5
-    val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+    // create model
    val model = new StreamingLinearRegressionWithSGD()
      .setInitialWeights(Vectors.dense(0.0))
      .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
+      .setNumIterations(25)

-    ssc.start()
-
-    // write data to a file stream
-    val history = new ArrayBuffer[Double](numBatches)
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-      // wait an extra few seconds to make sure the update finishes before new data arrive
-      Thread.sleep(4000)
-      history.append(math.abs(model.latestModel().weights(0) - 10.0))
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
    }

-    ssc.stop(stopSparkContext = false)
+    // create buffer to store intermediate fits
+    val history = new ArrayBuffer[Double](numBatches)

-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream, storing the intermediate results
+    // (we add a count to ensure the result is a DStream)
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)

+    // compute change in error
    val deltas = history.drop(1).zip(history.dropRight(1))
    // check error stability (it always either shrinks, or increases with small tol)
    assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
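The stability check pairs each intermediate error with its predecessor: history.drop(1).zip(history.dropRight(1)) yields (current, previous) tuples, and the assertion allows each step to grow by at most the 0.1 tolerance. A small worked example on plain collections (the history values are made up for illustration):

// hypothetical sequence of per-batch errors |weight - 10.0|
val history = Seq(5.0, 3.2, 2.1, 2.15, 1.4)
val deltas = history.drop(1).zip(history.dropRight(1))
// deltas == Seq((3.2, 5.0), (2.1, 3.2), (2.15, 2.1), (1.4, 2.15))
// the differences current - previous are -1.8, -1.1, 0.05, -0.75: all <= 0.1
assert(deltas.forall { case (curr, prev) => curr - prev <= 0.1 })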
@@ -133,63 +119,30 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
  // Test predictions on a stream
  test("predictions") {

-    val trainDir = Files.createTempDir()
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(1000)
-    val numBatches = 10
-    val nPoints = 100
-
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(trainDir.toString).map(LabeledPoint.parse)
+    // create model initialized with true weights
    val model = new StreamingLinearRegressionWithSGD()
-      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setInitialWeights(Vectors.dense(10.0, 10.0))
      .setStepSize(0.1)
-      .setNumIterations(50)
+      .setNumIterations(25)

-    model.trainOn(data)
-
-    ssc.start()
-
-    // write training data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
-      val file = new File(trainDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-    }
-
-    ssc.stop(stopSparkContext = false)
-
-    Utils.deleteRecursively(trainDir)
-
-    print(model.latestModel().weights.toArray.mkString(" "))
-    print(model.latestModel().intercept)
-
-    val ssc2 = new StreamingContext(sc, batchDuration)
-    val data2 = ssc2.textFileStream(testDir.toString).map(LabeledPoint.parse)
-
-    val history = new ArrayBuffer[Double](numBatches)
-    val predictions = model.predictOnValues(data2.map(x => (x.label, x.features)))
-    val errors = predictions.map(x => math.abs(x._1 - x._2))
-    errors.foreachRDD(rdd => history.append(rdd.reduce(_+_) / nPoints.toDouble))
-
-    ssc2.start()
-
-    // write test data to a file stream
-
-    // make a function
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data for testing
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
    }

-    println(history)
-
-    ssc2.stop(stopSparkContext = false)
+    // apply model predictions to test stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+    })
+    // collect the output as (true, estimated) tuples
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // compute the mean absolute error and check that it's always less than 0.1
+    val errors = output.map(batch => batch.map(
+      p => math.abs(p._1 - p._2)).reduce(_+_) / nPoints.toDouble)
+    assert(errors.forall(x => x <= 0.1))

  }
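Because the model starts at the true weights (10.0, 10.0), its predictions should track the labels up to the generator's noise, which is what the per-batch mean-absolute-error bound verifies. The same computation on a hypothetical batch of (label, prediction) pairs, as a sanity check of the arithmetic:

// made-up batch of (true label, estimated label) tuples
val batchOutput: Seq[(Double, Double)] = Seq((20.0, 19.95), (0.0, 0.02), (10.0, 10.1))
val mae = batchOutput.map { case (truth, est) => math.abs(truth - est) }.sum / batchOutput.size
// (0.05 + 0.02 + 0.1) / 3 is about 0.057, within the 0.1 bound
assert(mae <= 0.1)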