@@ -49,7 +49,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
-  test("streaming linear regression parameter accuracy") {
+  test("parameter accuracy") {
 
     val testDir = Files.createTempDir()
     val numBatches = 10
@@ -76,7 +76,6 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
 
     ssc.stop(stopSparkContext = false)
 
-    System.clearProperty("spark.driver.port")
     Utils.deleteRecursively(testDir)
 
     // check accuracy of final parameter estimates
@@ -91,7 +90,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
-  test("streaming linear regression parameter convergence") {
+  test("parameter convergence") {
 
     val testDir = Files.createTempDir()
     val batchDuration = Milliseconds(2000)
@@ -121,7 +120,6 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
 
     ssc.stop(stopSparkContext = false)
 
-    System.clearProperty("spark.driver.port")
     Utils.deleteRecursively(testDir)
 
     val deltas = history.drop(1).zip(history.dropRight(1))
@@ -132,4 +130,65 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
 
   }
 
+  // Test predictions on a stream
+  test("predictions") {
+
+    val trainDir = Files.createTempDir()
+    val testDir = Files.createTempDir()
+    val batchDuration = Milliseconds(1000)
+    val numBatches = 10
+    val nPoints = 100
+
+    val ssc = new StreamingContext(sc, batchDuration)
+    val data = ssc.textFileStream(trainDir.toString).map(LabeledPoint.parse)
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.1)
+      .setNumIterations(50)
+
+    model.trainOn(data)
+
+    ssc.start()
+
+    // write training data to a file stream
+    for (i <- 0 until numBatches) {
+      val samples = LinearDataGenerator.generateLinearInput(
+        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+      val file = new File(trainDir, i.toString)
+      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
+      Thread.sleep(batchDuration.milliseconds)
+    }
+
+    ssc.stop(stopSparkContext = false)
+
+    Utils.deleteRecursively(trainDir)
+
+    val ssc2 = new StreamingContext(sc, batchDuration)
+    val data2 = ssc2.textFileStream(testDir.toString).map(LabeledPoint.parse)
+
+    val history = new ArrayBuffer[Double](numBatches)
+    val predictions = model.predictOnValues(data2.map(x => (x.label, x.features)))
+    val errors = predictions.map(x => math.abs(x._1 - x._2))
+    errors.foreachRDD(rdd => history.append(rdd.reduce(_ + _) / nPoints.toDouble))
+
+    ssc2.start()
+
+    // write test data to a file stream (same loop as above; see the helper sketch after the diff)
+    for (i <- 0 until numBatches) {
+      val samples = LinearDataGenerator.generateLinearInput(
+        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
+      val file = new File(testDir, i.toString)
+      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
+      Thread.sleep(batchDuration.milliseconds)
+    }
+
+    ssc2.stop(stopSparkContext = false)
+
+    Utils.deleteRecursively(testDir)
+
+    // check that the mean absolute error on each test batch is small
+    assert(history.forall(x => x <= 0.1))
+
+  }
+
 }
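
Note: the two file-writing loops in the "predictions" test are identical apart from the target directory. Below is a minimal sketch of the shared helper a follow-up could factor out into the suite; the name writeStreamBatches and its parameter list are illustrative, not part of this patch (it reuses the suite's existing imports: java.io.File, java.nio.charset.Charset, Guava's Files, LinearDataGenerator, and org.apache.spark.streaming.Duration):

  // Write one file of generated points per batch so that textFileStream
  // picks each file up as a separate streaming batch.
  def writeStreamBatches(dir: File, numBatches: Int, nPoints: Int, batchDuration: Duration): Unit = {
    for (i <- 0 until numBatches) {
      val samples = LinearDataGenerator.generateLinearInput(
        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
      val file = new File(dir, i.toString)
      Files.write(samples.map(_.toString).mkString("\n"), file, Charset.forName("UTF-8"))
      Thread.sleep(batchDuration.milliseconds)
    }
  }

With this helper, both loops reduce to writeStreamBatches(trainDir, numBatches, nPoints, batchDuration) and writeStreamBatches(testDir, numBatches, nPoints, batchDuration).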
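The 0.1 bound in the final assertion is consistent with the generator's noise level: LinearDataGenerator.generateLinearInput adds zero-mean Gaussian noise scaled by its eps parameter, which appears to default to 0.1 (an assumption worth verifying against the generator's source). For N(0, sigma^2) noise, the expected absolute error is sigma * sqrt(2/pi) ~ 0.08, so a well-fit model should sit just under the bound. A quick sanity check of that arithmetic:

  object NoiseBound extends App {
    // expected |x| for x ~ N(0, sigma^2) is sigma * sqrt(2 / pi)
    val sigma = 0.1  // generateLinearInput's default eps (assumption)
    val expectedMae = sigma * math.sqrt(2.0 / math.Pi)
    println(f"noise floor for per-point |error|: $expectedMae%.4f")  // ~0.0798, below the 0.1 bound
  }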