Skip to content

Commit 9ae9d53

Browse files
SPARK-3278 changes after PR feedback apache#3519. Binary search used for isotonic regression model predictions
1 parent fad4bf9 commit 9ae9d53

File tree

3 files changed

+47
-33
lines changed

3 files changed

+47
-33
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,25 @@
1818
package org.apache.spark.mllib.regression
1919

2020
import java.io.Serializable
21+
import java.util.Arrays.binarySearch
2122

2223
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
2324
import org.apache.spark.rdd.RDD
2425

2526
/**
2627
* Regression model for Isotonic regression
2728
*
28-
* @param predictions Weights computed for every feature.
29-
* @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
29+
* @param features Array of features, sorted in ascending order (required by the binary search used in predict).
30+
* @param labels Array of labels; labels(i) is the label associated with features(i).
3031
*/
3132
class IsotonicRegressionModel (
32-
val predictions: Seq[(Double, Double, Double)],
33-
val isotonic: Boolean)
33+
features: Array[Double],
34+
val labels: Array[Double])
3435
extends Serializable {
3536

3637
/**
3738
* Predict labels for provided features
39+
* using a piecewise constant function.
3840
*
3941
* @param testData features to be labeled
4042
* @return predicted labels
@@ -44,6 +46,7 @@ class IsotonicRegressionModel (
4446

4547
/**
4648
* Predict labels for provided features
49+
* using a piecewise constant function.
4750
*
4851
* @param testData features to be labeled
4952
* @return predicted labels
@@ -53,13 +56,25 @@ class IsotonicRegressionModel (
5356

5457
/**
5558
* Predict a single label
59+
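* using a piecewise constant function: returns the label of the largest feature not greater than testData, or the first label if testData is below every feature.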
5660
*
5761
* @param testData feature to be labeled
5862
* @return predicted label
5963
*/
60-
def predict(testData: Double): Double =
61-
// Take the highest of data points smaller than our feature or data point with lowest feature
62-
(predictions.head +: predictions.filter(y => y._2 <= testData)).last._1
64+
def predict(testData: Double): Double = {
65+
val result = binarySearch(features, testData)
66+
67+
val index =
68+
if (result == -1) {
69+
0
70+
} else if (result < 0) {
71+
-result - 2
72+
} else {
73+
result
74+
}
75+
76+
labels(index)
77+
}
6378
}
6479

6580
/**
@@ -93,9 +108,13 @@ class IsotonicRegression
93108
* @return isotonic regression model
94109
*/
95110
protected def createModel(
96-
predictions: Seq[(Double, Double, Double)],
111+
predictions: Array[(Double, Double, Double)],
97112
isotonic: Boolean): IsotonicRegressionModel = {
98-
new IsotonicRegressionModel(predictions, isotonic)
113+
114+
val labels = predictions.map(_._1)
115+
val features = predictions.map(_._2)
116+
117+
new IsotonicRegressionModel(features, labels)
99118
}
100119

101120
/**
@@ -167,7 +186,7 @@ class IsotonicRegression
167186
*/
168187
private def parallelPoolAdjacentViolators(
169188
testData: RDD[(Double, Double, Double)],
170-
isotonic: Boolean): Seq[(Double, Double, Double)] = {
189+
isotonic: Boolean): Array[(Double, Double, Double)] = {
171190

172191
val parallelStepResult = testData
173192
.sortBy(_._2)
@@ -213,7 +232,7 @@ object IsotonicRegression {
213232
isotonic: Boolean): IsotonicRegressionModel = {
214233
new IsotonicRegression()
215234
.run(
216-
input.rdd.map(x => (x._1.doubleValue(), x._2.doubleValue(), x._3.doubleValue())),
235+
input.rdd.asInstanceOf[RDD[(Double, Double, Double)]],
217236
isotonic)
218237
}
219238
}

mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void tearDown() {
4949
double difference(List<Tuple3<Double, Double, Double>> expected, IsotonicRegressionModel model) {
5050
double diff = 0;
5151

52-
for(int i = 0; i < model.predictions().length(); i++) {
52+
for(int i = 0; i < model.labels().length; i++) {
5353
Tuple3<Double, Double, Double> exp = expected.get(i);
5454
diff += Math.abs(model.predict(exp._2()) - exp._1());
5555
}

mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ class IsotonicRegressionSuite
3838
val alg = new IsotonicRegression
3939
val model = alg.run(trainRDD, true)
4040

41-
model.predictions should be(
42-
generateIsotonicInput(
43-
1, 2, 7d/3, 7d/3, 7d/3, 6, 7, 8, 10, 10, 10, 12, 14, 15, 16.5, 16.5, 17, 18, 19, 20))
41+
model.labels should be(
42+
Array(1, 2, 7d/3, 7d/3, 7d/3, 6, 7, 8, 10, 10, 10, 12, 14, 15, 16.5, 16.5, 17, 18, 19, 20))
4443
}
4544

4645
test("increasing isotonic regression using api") {
@@ -50,9 +49,8 @@ class IsotonicRegressionSuite
5049

5150
val model = IsotonicRegression.train(trainRDD, true)
5251

53-
model.predictions should be(
54-
generateIsotonicInput(
55-
1, 2, 7d/3, 7d/3, 7d/3, 6, 7, 8, 10, 10, 10, 12, 14, 15, 16.5, 16.5, 17, 18, 19, 20))
52+
model.labels should be(
53+
Array(1, 2, 7d/3, 7d/3, 7d/3, 6, 7, 8, 10, 10, 10, 12, 14, 15, 16.5, 16.5, 17, 18, 19, 20))
5654
}
5755

5856
test("isotonic regression with size 0") {
@@ -61,7 +59,7 @@ class IsotonicRegressionSuite
6159
val alg = new IsotonicRegression
6260
val model = alg.run(trainRDD, true)
6361

64-
model.predictions should be(List())
62+
model.labels should be(Array())
6563
}
6664

6765
test("isotonic regression with size 1") {
@@ -70,7 +68,7 @@ class IsotonicRegressionSuite
7068
val alg = new IsotonicRegression
7169
val model = alg.run(trainRDD, true)
7270

73-
model.predictions should be(generateIsotonicInput(1))
71+
model.labels should be(Array(1.0))
7472
}
7573

7674
test("isotonic regression strictly increasing sequence") {
@@ -79,7 +77,7 @@ class IsotonicRegressionSuite
7977
val alg = new IsotonicRegression
8078
val model = alg.run(trainRDD, true)
8179

82-
model.predictions should be(generateIsotonicInput(1, 2, 3, 4, 5))
80+
model.labels should be(Array(1, 2, 3, 4, 5))
8381
}
8482

8583
test("isotonic regression strictly decreasing sequence") {
@@ -88,7 +86,7 @@ class IsotonicRegressionSuite
8886
val alg = new IsotonicRegression
8987
val model = alg.run(trainRDD, true)
9088

91-
model.predictions should be(generateIsotonicInput(3, 3, 3, 3, 3))
89+
model.labels should be(Array(3, 3, 3, 3, 3))
9290
}
9391

9492
test("isotonic regression with last element violating monotonicity") {
@@ -97,7 +95,7 @@ class IsotonicRegressionSuite
9795
val alg = new IsotonicRegression
9896
val model = alg.run(trainRDD, true)
9997

100-
model.predictions should be(generateIsotonicInput(1, 2, 3, 3, 3))
98+
model.labels should be(Array(1, 2, 3, 3, 3))
10199
}
102100

103101
test("isotonic regression with first element violating monotonicity") {
@@ -106,7 +104,7 @@ class IsotonicRegressionSuite
106104
val alg = new IsotonicRegression
107105
val model = alg.run(trainRDD, true)
108106

109-
model.predictions should be(generateIsotonicInput(3, 3, 3, 4, 5))
107+
model.labels should be(Array(3, 3, 3, 4, 5))
110108
}
111109

112110
test("isotonic regression with negative labels") {
@@ -115,7 +113,7 @@ class IsotonicRegressionSuite
115113
val alg = new IsotonicRegression
116114
val model = alg.run(trainRDD, true)
117115

118-
model.predictions should be(generateIsotonicInput(-1.5, -1.5, 0, 0, 0))
116+
model.labels should be(Array(-1.5, -1.5, 0, 0, 0))
119117
}
120118

121119
test("isotonic regression with unordered input") {
@@ -124,7 +122,7 @@ class IsotonicRegressionSuite
124122
val alg = new IsotonicRegression
125123
val model = alg.run(trainRDD, true)
126124

127-
model.predictions should be(generateIsotonicInput(1, 2, 3, 4, 5))
125+
model.labels should be(Array(1, 2, 3, 4, 5))
128126
}
129127

130128
test("weighted isotonic regression") {
@@ -134,8 +132,7 @@ class IsotonicRegressionSuite
134132
val alg = new IsotonicRegression
135133
val model = alg.run(trainRDD, true)
136134

137-
model.predictions should be(
138-
generateWeightedIsotonicInput(Seq(1, 2, 2.75, 2.75,2.75), Seq(1, 1, 1, 1, 2)))
135+
model.labels should be(Array(1, 2, 2.75, 2.75,2.75))
139136
}
140137

141138
test("weighted isotonic regression with weights lower than 1") {
@@ -145,8 +142,7 @@ class IsotonicRegressionSuite
145142
val alg = new IsotonicRegression
146143
val model = alg.run(trainRDD, true)
147144

148-
model.predictions.map(p => p.copy(_1 = round(p._1))) should be(
149-
generateWeightedIsotonicInput(Seq(1, 2, 3.3/1.2, 3.3/1.2, 3.3/1.2), Seq(1, 1, 1, 0.1, 0.1)))
145+
model.labels.map(round) should be(Array(1, 2, 3.3/1.2, 3.3/1.2, 3.3/1.2))
150146
}
151147

152148
test("weighted isotonic regression with negative weights") {
@@ -155,8 +151,7 @@ class IsotonicRegressionSuite
155151
val alg = new IsotonicRegression
156152
val model = alg.run(trainRDD, true)
157153

158-
model.predictions should be(
159-
generateWeightedIsotonicInput(Seq(1.0, 10.0/6, 10.0/6, 10.0/6, 10.0/6), Seq(-1, 1, -3, 1, -5)))
154+
model.labels should be(Array(1.0, 10.0/6, 10.0/6, 10.0/6, 10.0/6))
160155
}
161156

162157
test("weighted isotonic regression with zero weights") {
@@ -165,7 +160,7 @@ class IsotonicRegressionSuite
165160
val alg = new IsotonicRegression
166161
val model = alg.run(trainRDD, true)
167162

168-
model.predictions should be(generateWeightedIsotonicInput(Seq(1, 2, 2, 2, 2), Seq(0, 0, 0, 1, 0)))
163+
model.labels should be(Array(1, 2, 2, 2, 2))
169164
}
170165

171166
test("isotonic regression prediction") {

0 commit comments

Comments
 (0)