|
17 | 17 |
|
18 | 18 | package org.apache.spark.mllib.regression
|
19 | 19 |
|
20 |
| -import org.apache.spark.annotation.DeveloperApi |
21 | 20 | import org.apache.spark.Logging
|
| 21 | +import org.apache.spark.annotation.DeveloperApi |
| 22 | +import org.apache.spark.mllib.linalg.Vector |
22 | 23 | import org.apache.spark.streaming.dstream.DStream
|
23 | 24 |
|
24 | 25 | /**
|
@@ -92,15 +93,30 @@ abstract class StreamingLinearAlgorithm[
|
92 | 93 | /**
|
93 | 94 | * Use the model to make predictions on batches of data from a DStream
|
94 | 95 | *
|
95 |
| - * @param data DStream containing labeled data |
| 96 | + * @param data DStream containing feature vectors |
96 | 97 | * @return DStream containing predictions
|
97 | 98 | */
|
98 |
| - def predictOn(data: DStream[LabeledPoint]): DStream[Double] = { |
| 99 | + def predictOn(data: DStream[Vector]): DStream[Double] = { |
99 | 100 | if (Option(model.weights) == None) {
|
100 |
| - logError("Initial weights must be set before starting prediction") |
101 |
| - throw new IllegalArgumentException |
| 101 | + val msg = "Initial weights must be set before starting prediction" |
| 102 | + logError(msg) |
| 103 | + throw new IllegalArgumentException(msg) |
102 | 104 | }
|
103 |
| - data.map(x => model.predict(x.features)) |
| 105 | + data.map(model.predict) |
104 | 106 | }
|
105 | 107 |
|
| 108 | + /** |
| 109 | + * Use the model to make predictions on the values of a DStream and carry over its keys. |
| 110 | + * @param data DStream containing feature vectors |
| 111 | + * @tparam K key type |
| 112 | + * @return DStream containing the input keys and the predictions as values |
| 113 | + */ |
| 114 | + def predictOnValues[K](data: DStream[(K, Vector)]): DStream[(K, Double)] = { |
| 115 | + if (Option(model.weights) == None) { |
| 116 | + val msg = "Initial weights must be set before starting prediction" |
| 117 | + logError(msg) |
| 118 | + throw new IllegalArgumentException(msg) |
| 119 | + } |
| 120 | + data.mapPartitions(_.map(x => (x._1, model.predict(x._2))), preservePartitioning = true) |
| 121 | + } |
106 | 122 | }
|
0 commit comments