Skip to content

Commit 3c2954b

Browse files
SPARK-3278 Isotonic regression java api
1 parent 45aa7e8 commit 3c2954b

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class IsotonicRegressionModel (
4646
* @param testData features to be labeled
4747
* @return predicted labels
4848
*/
49-
def predict(testData: JavaRDD[java.lang.Double]): RDD[java.lang.Double] =
50-
testData.rdd.map(x => x.doubleValue()).map(predict)
49+
def predict(testData: JavaRDD[java.lang.Double]): JavaRDD[java.lang.Double] =
50+
testData.rdd.map(_.doubleValue()).map(predict).map(new java.lang.Double(_))
5151

5252
/**
5353
* Predict a single label

mllib/src/main/scala/org/apache/spark/mllib/util/IsotonicDataGenerator.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,28 @@ import org.apache.spark.annotation.DeveloperApi
2121
import scala.collection.JavaConversions._
2222
import java.lang.{Double => JDouble}
2323

24+
/**
25+
* :: DeveloperApi ::
26+
* Generate test data for Isotonic regresision.
27+
*/
2428
@DeveloperApi
2529
object IsotonicDataGenerator {
2630

2731
/**
2832
* Return a Java List of ordered labeled points
33+
*
2934
* @param labels list of labels for the data points
3035
* @return Java List of input.
3136
*/
3237
def generateIsotonicInputAsList(labels: Array[Double]): java.util.List[(JDouble, JDouble)] = {
33-
seqAsJavaList(generateIsotonicInput(wrapDoubleArray(labels):_*).map(x => (new JDouble(x._1), new JDouble(x._2))))
38+
seqAsJavaList(
39+
generateIsotonicInput(
40+
wrapDoubleArray(labels):_*).map(x => (new JDouble(x._1), new JDouble(x._2))))
3441
}
3542

3643
/**
3744
* Return an ordered sequence of labeled data points with default weights
45+
*
3846
* @param labels list of labels for the data points
3947
* @return sequence of data points
4048
*/
@@ -45,6 +53,7 @@ object IsotonicDataGenerator {
4553

4654
/**
4755
* Return an ordered sequence of labeled weighted data points
56+
*
4857
* @param labels list of labels for the data points
4958
* @param weights list of weights for the data points
5059
* @return sequence of data points

mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ public Double call(Tuple2<Double, Double> v) throws Exception {
8686
}
8787
});
8888

89-
Double[] predictions = model.predict(testRDD).collect();
89+
List<Double> predictions = model.predict(testRDD).collect();
9090

91-
Assert.assertTrue(predictions[0] == 1d);
92-
Assert.assertTrue(predictions[11] == 12d);
91+
Assert.assertTrue(predictions.get(0) == 1d);
92+
Assert.assertTrue(predictions.get(11) == 12d);
9393
}
9494
}
9595

0 commit comments

Comments
 (0)