Skip to content

Commit 7b14578

Browse files
yanboliang authored and
mengxr committed
[SPARK-6267] [MLLIB] Python API for IsotonicRegression
https://issues.apache.org/jira/browse/SPARK-6267 Author: Yanbo Liang <ybliang8@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #5890 from yanboliang/spark-6267 and squashes the following commits: f20541d [Yanbo Liang] Merge pull request #3 from mengxr/SPARK-6267 7f202f9 [Xiangrui Meng] use Vector to have the best Python 2&3 compatibility 4bccfee [Yanbo Liang] fix doctest ec09412 [Yanbo Liang] fix typos 8214bbb [Yanbo Liang] fix code style 5c8ebe5 [Yanbo Liang] Python API for IsotonicRegression
1 parent ba2b566 commit 7b14578

File tree

3 files changed

+106
-4
lines changed

3 files changed

+106
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,24 @@ private[python] class PythonMLLibAPI extends Serializable {
282282
map(_.asInstanceOf[Object]).asJava
283283
}
284284

285+
/**
 * Java stub for Python mllib IsotonicRegression.train().
 *
 * Elements 0, 1 and 2 of every input vector are unpacked into a 3-tuple (the Python caller
 * documents them as label, feature and weight) and fed to [[IsotonicRegression.run]].
 *
 * @param data RDD of vectors with at least three elements each
 * @param isotonic forwarded to `IsotonicRegression.setIsotonic`
 * @return a two-element Java list: the model's boundary vector followed by its
 *         prediction vector
 */
def trainIsotonicRegressionModel(
    data: JavaRDD[Vector],
    isotonic: Boolean): JList[Object] = {
  val isotonicRegressionAlg = new IsotonicRegression().setIsotonic(isotonic)
  val input = data.rdd.map { x =>
    (x(0), x(1), x(2))
  }.persist(StorageLevel.MEMORY_AND_DISK)
  try {
    val model = isotonicRegressionAlg.run(input)
    List[AnyRef](model.boundaryVector, model.predictionVector).asJava
  } finally {
    // Release the RDD that was persisted above. Unpersisting `data.rdd` instead would leave
    // `input` cached and could drop a cache level the caller had set on their own RDD.
    input.unpersist(blocking = false)
  }
}
302+
285303
/**
286304
* Java stub for Python mllib KMeans.run()
287305
*/

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,20 @@ import java.io.Serializable
2121
import java.lang.{Double => JDouble}
2222
import java.util.Arrays.binarySearch
2323

24+
import scala.collection.JavaConverters._
2425
import scala.collection.mutable.ArrayBuffer
2526

2627
import org.json4s._
2728
import org.json4s.JsonDSL._
2829
import org.json4s.jackson.JsonMethods._
2930

31+
import org.apache.spark.SparkContext
3032
import org.apache.spark.annotation.Experimental
3133
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
34+
import org.apache.spark.mllib.linalg.{Vector, Vectors}
3235
import org.apache.spark.mllib.util.{Loader, Saveable}
3336
import org.apache.spark.rdd.RDD
34-
import org.apache.spark.SparkContext
35-
import org.apache.spark.sql.{DataFrame, SQLContext}
37+
import org.apache.spark.sql.SQLContext
3638

3739
/**
3840
* :: Experimental ::
@@ -57,6 +59,13 @@ class IsotonicRegressionModel (
5759
assertOrdered(boundaries)
5860
assertOrdered(predictions)(predictionOrd)
5961

62+
/**
 * Java-friendly auxiliary constructor.
 *
 * Both iterables are copied into arrays and the boxed flag is unboxed before delegating to
 * the primary constructor.
 *
 * @param boundaries iterable of boundary values
 * @param predictions iterable of prediction values
 * @param isotonic boxed Boolean flag passed through to the primary constructor
 */
def this(
    boundaries: java.lang.Iterable[Double],
    predictions: java.lang.Iterable[Double],
    isotonic: java.lang.Boolean) = {
  this(
    boundaries.asScala.toArray,
    predictions.asScala.toArray,
    isotonic.booleanValue())
}
68+
6069
/** Asserts the input array is monotone with the given ordering. */
6170
private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = {
6271
var i = 1
@@ -132,6 +141,12 @@ class IsotonicRegressionModel (
132141
}
133142
}
134143

144+
/** Returns the model's boundaries as a dense Vector; a convenience accessor for the Python API. */
145+
private[mllib] def boundaryVector: Vector = Vectors.dense(boundaries)
146+
147+
/** Returns the model's predictions as a dense Vector; a convenience accessor for the Python API. */
148+
private[mllib] def predictionVector: Vector = Vectors.dense(predictions)
149+
135150
/**
 * Saves this model by delegating to IsotonicRegressionModel.SaveLoadV1_0.save, passing the
 * given SparkContext and path together with this model's boundaries, predictions and
 * isotonic flag.
 */
override def save(sc: SparkContext, path: String): Unit = {
136151
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
137152
}

python/pyspark/mllib/regression.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@
1818
import numpy as np
1919
from numpy import array
2020

21+
from pyspark import RDD
2122
from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
22-
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
23+
from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector
2324
from pyspark.mllib.util import Saveable, Loader
2425

2526
__all__ = ['LabeledPoint', 'LinearModel',
2627
'LinearRegressionModel', 'LinearRegressionWithSGD',
2728
'RidgeRegressionModel', 'RidgeRegressionWithSGD',
28-
'LassoModel', 'LassoWithSGD']
29+
'LassoModel', 'LassoWithSGD', 'IsotonicRegressionModel',
30+
'IsotonicRegression']
2931

3032

3133
class LabeledPoint(object):
@@ -396,6 +398,73 @@ def train(rdd, i):
396398
return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights)
397399

398400

401+
class IsotonicRegressionModel(Saveable, Loader):

    """Regression model for isotonic regression.

    :param boundaries: Array of boundary (feature) values, used as the
                       x-coordinates for interpolation.
    :param predictions: Array of predictions, one per boundary, used as the
                        y-coordinates for interpolation.
    :param isotonic: Whether this is isotonic or antitonic.

    >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)]
    >>> irm = IsotonicRegression.train(sc.parallelize(data))
    >>> irm.predict(3)
    2.0
    >>> irm.predict(5)
    16.5
    >>> irm.predict(sc.parallelize([3, 5])).collect()
    [2.0, 16.5]
    >>> import tempfile
    >>> path = tempfile.mkdtemp()
    >>> irm.save(sc, path)
    >>> sameModel = IsotonicRegressionModel.load(sc, path)
    >>> sameModel.predict(3)
    2.0
    >>> sameModel.predict(5)
    16.5
    >>> from shutil import rmtree
    >>> try:
    ...     rmtree(path)
    ... except OSError:
    ...     pass
    """

    def __init__(self, boundaries, predictions, isotonic):
        self.boundaries = boundaries
        self.predictions = predictions
        self.isotonic = isotonic

    def predict(self, x):
        """
        Predict the label for a feature value, or for every element of an RDD.

        :param x: A single feature value, or an RDD of feature values.
        :return: The value interpolated with np.interp over
                 (boundaries, predictions), or an RDD of such values
                 when x is an RDD.
        """
        if isinstance(x, RDD):
            return x.map(lambda v: self.predict(v))
        return np.interp(x, self.boundaries, self.predictions)

    def save(self, sc, path):
        """
        Save this model to the given path by rebuilding the JVM model and
        calling its save method.
        """
        java_boundaries = _py2java(sc, self.boundaries.tolist())
        java_predictions = _py2java(sc, self.predictions.tolist())
        java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel(
            java_boundaries, java_predictions, self.isotonic)
        java_model.save(sc._jsc.sc(), path)

    @classmethod
    def load(cls, sc, path):
        """
        Load a model from the given path through the JVM loader and copy
        its boundaries, predictions and isotonic flag into a Python model.
        """
        java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel.load(
            sc._jsc.sc(), path)
        py_boundaries = _java2py(sc, java_model.boundaryVector()).toArray()
        py_predictions = _java2py(sc, java_model.predictionVector()).toArray()
        # `isotonic` is a method on the Py4J proxy; it must be called to get
        # the Boolean value rather than a member object that save() cannot
        # pass back to the Java constructor.
        return IsotonicRegressionModel(py_boundaries, py_predictions, java_model.isotonic())
451+
452+
453+
class IsotonicRegression(object):
    """
    Run IsotonicRegression algorithm to obtain isotonic regression model.
    """

    @classmethod
    def train(cls, data, isotonic=True):
        """
        Train an isotonic regression model on the given data.

        :param data: RDD of (label, feature, weight) tuples.
        :param isotonic: Whether this is isotonic or antitonic.
        """
        # Each tuple is converted to a vector before being handed to the JVM.
        vectors = data.map(_convert_to_vector)
        fitted = callMLlibFunc("trainIsotonicRegressionModel", vectors, bool(isotonic))
        boundaries, predictions = fitted
        return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic)
466+
467+
399468
def _test():
400469
import doctest
401470
from pyspark import SparkContext

0 commit comments

Comments
 (0)