Skip to content

Commit b8b1620

Browse files
Removed WeightedLabeledPoint. Replaced by tuple of doubles
1 parent 34760d5 commit b8b1620

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ object MonotonicityConstraint {
6666
* @param monotonicityConstraint specifies if the sequence is increasing or decreasing
6767
*/
6868
class IsotonicRegressionModel(
69-
val predictions: Seq[WeightedLabeledPoint],
69+
val predictions: Seq[(Double, Double, Double)],
7070
val monotonicityConstraint: MonotonicityConstraint)
7171
extends RegressionModel {
7272

@@ -76,7 +76,7 @@ class IsotonicRegressionModel(
7676
override def predict(testData: Vector): Double = {
7777
// Take the highest of data points smaller than our feature or data point with lowest feature
7878
(predictions.head +:
79-
predictions.filter(y => y.features.toArray.head <= testData.toArray.head)).last.label
79+
predictions.filter(y => y._2 <= testData.toArray.head)).last._1
8080
}
8181
}
8282

@@ -95,7 +95,7 @@ trait IsotonicRegressionAlgorithm
9595
* @return isotonic regression model
9696
*/
9797
protected def createModel(
98-
predictions: Seq[WeightedLabeledPoint],
98+
predictions: Seq[(Double, Double, Double)],
9999
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
100100

101101
/**
@@ -106,7 +106,7 @@ trait IsotonicRegressionAlgorithm
106106
* @return isotonic regression model
107107
*/
108108
def run(
109-
input: RDD[WeightedLabeledPoint],
109+
input: RDD[(Double, Double, Double)],
110110
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
111111
}
112112

@@ -117,15 +117,15 @@ class PoolAdjacentViolators private [mllib]
117117
extends IsotonicRegressionAlgorithm {
118118

119119
override def run(
120-
input: RDD[WeightedLabeledPoint],
120+
input: RDD[(Double, Double, Double)],
121121
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
122122
createModel(
123123
parallelPoolAdjacentViolators(input, monotonicityConstraint),
124124
monotonicityConstraint)
125125
}
126126

127127
override protected def createModel(
128-
predictions: Seq[WeightedLabeledPoint],
128+
predictions: Seq[(Double, Double, Double)],
129129
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
130130
new IsotonicRegressionModel(predictions, monotonicityConstraint)
131131
}
@@ -194,12 +194,12 @@ class PoolAdjacentViolators private [mllib]
194194
* @return result
195195
*/
196196
private def parallelPoolAdjacentViolators(
197-
testData: RDD[WeightedLabeledPoint],
198-
monotonicityConstraint: MonotonicityConstraint): Seq[WeightedLabeledPoint] = {
197+
testData: RDD[(Double, Double, Double)],
198+
monotonicityConstraint: MonotonicityConstraint): Seq[(Double, Double, Double)] = {
199199

200200
poolAdjacentViolators(
201201
testData
202-
.sortBy(_.features.toArray.head)
202+
.sortBy(_._2)
203203
.cache()
204204
.mapPartitions(it => poolAdjacentViolators(it.toArray, monotonicityConstraint).toIterator)
205205
.collect(), monotonicityConstraint)
@@ -224,7 +224,7 @@ object IsotonicRegression {
224224
* @param monotonicityConstraint Isotonic (increasing) or Antitonic (decreasing) sequence
225225
*/
226226
def train(
227-
input: RDD[WeightedLabeledPoint],
227+
input: RDD[(Double, Double, Double)],
228228
monotonicityConstraint: MonotonicityConstraint = Isotonic): IsotonicRegressionModel = {
229229
new PoolAdjacentViolators().run(input, monotonicityConstraint)
230230
}

0 commit comments

Comments
 (0)