Skip to content

Commit 0925efa

Browse files
committed
add predictOnValues to StreamingLR and fix predictOn
1 parent d1d0ee4 commit 0925efa

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,10 +59,10 @@ object StreamingLinearRegression {
5959
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
6060

6161
val model = new StreamingLinearRegressionWithSGD()
62-
.setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0)))
62+
.setInitialWeights(Vectors.zeros(args(3).toInt))
6363

6464
model.trainOn(trainingData)
65-
model.predictOn(testData).print()
65+
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
6666

6767
ssc.start()
6868
ssc.awaitTermination()

mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,9 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import org.apache.spark.annotation.DeveloperApi
2120
import org.apache.spark.Logging
21+
import org.apache.spark.annotation.DeveloperApi
22+
import org.apache.spark.mllib.linalg.Vector
2223
import org.apache.spark.streaming.dstream.DStream
2324

2425
/**
@@ -92,15 +93,30 @@ abstract class StreamingLinearAlgorithm[
9293
/**
9394
* Use the model to make predictions on batches of data from a DStream
9495
*
95-
* @param data DStream containing labeled data
96+
* @param data DStream containing feature vectors
9697
* @return DStream containing predictions
9798
*/
98-
def predictOn(data: DStream[LabeledPoint]): DStream[Double] = {
99+
def predictOn(data: DStream[Vector]): DStream[Double] = {
99100
if (Option(model.weights) == None) {
100-
logError("Initial weights must be set before starting prediction")
101-
throw new IllegalArgumentException
101+
val msg = "Initial weights must be set before starting prediction"
102+
logError(msg)
103+
throw new IllegalArgumentException(msg)
102104
}
103-
data.map(x => model.predict(x.features))
105+
data.map(model.predict)
104106
}
105107

108+
/**
109+
* Use the model to make predictions on the values of a DStream and carry over its keys.
110+
* @param data DStream containing feature vectors
111+
* @tparam K key type
112+
* @return DStream containing the input keys and the predictions as values
113+
*/
114+
def predictOnValues[K](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
115+
if (Option(model.weights) == None) {
116+
val msg = "Initial weights must be set before starting prediction"
117+
logError(msg)
118+
throw new IllegalArgumentException(msg)
119+
}
120+
data.mapPartitions(_.map(x => (x._1, model.predict(x._2))), preservePartitioning = true)
121+
}
106122
}

0 commit comments

Comments
 (0)