Skip to content

Commit 7f202f9

Browse files
committed
use Vector to have the best Python 2&3 compatibility
1 parent 4bccfee commit 7f202f9

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,12 @@ private[python] class PythonMLLibAPI extends Serializable {
289289
data: JavaRDD[Vector],
290290
isotonic: Boolean): JList[Object] = {
291291
val isotonicRegressionAlg = new IsotonicRegression().setIsotonic(isotonic)
292+
val input = data.rdd.map { x =>
293+
(x(0), x(1), x(2))
294+
}.persist(StorageLevel.MEMORY_AND_DISK)
292295
try {
293-
val model = isotonicRegressionAlg.run(data.rdd.map(_.toArray).map {
294-
x => (x(0), x(1), x(2)) }.persist(StorageLevel.MEMORY_AND_DISK))
295-
List(model.boundaries, model.predictions).map(_.asInstanceOf[Object]).asJava
296+
val model = isotonicRegressionAlg.run(input)
297+
List[AnyRef](model.boundaryVector, model.predictionVector).asJava
296298
} finally {
297299
data.rdd.unpersist(blocking = false)
298300
}

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ import org.json4s._
2828
import org.json4s.JsonDSL._
2929
import org.json4s.jackson.JsonMethods._
3030

31+
import org.apache.spark.SparkContext
3132
import org.apache.spark.annotation.Experimental
3233
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
34+
import org.apache.spark.mllib.linalg.{Vector, Vectors}
3335
import org.apache.spark.mllib.util.{Loader, Saveable}
3436
import org.apache.spark.rdd.RDD
35-
import org.apache.spark.SparkContext
36-
import org.apache.spark.sql.{DataFrame, SQLContext}
37+
import org.apache.spark.sql.SQLContext
3738

3839
/**
3940
* :: Experimental ::
@@ -140,6 +141,12 @@ class IsotonicRegressionModel (
140141
}
141142
}
142143

144+
/** A convenient method for boundaries called by the Python API. */
145+
private[mllib] def boundaryVector: Vector = Vectors.dense(boundaries)
146+
147+
/** A convenient method for predictions called by the Python API. */
148+
private[mllib] def predictionVector: Vector = Vectors.dense(predictions)
149+
143150
override def save(sc: SparkContext, path: String): Unit = {
144151
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
145152
}

python/pyspark/mllib/regression.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -445,25 +445,24 @@ def save(self, sc, path):
445445
def load(cls, sc, path):
446446
java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel.load(
447447
sc._jsc.sc(), path)
448-
py_boundaries = _java2py(sc, java_model.boundaries())
449-
py_predictions = _java2py(sc, java_model.predictions())
450-
return IsotonicRegressionModel(np.array(py_boundaries),
451-
np.array(py_predictions), java_model.isotonic)
448+
py_boundaries = _java2py(sc, java_model.boundaryVector()).toArray()
449+
py_predictions = _java2py(sc, java_model.predictionVector()).toArray()
450+
return IsotonicRegressionModel(py_boundaries, py_predictions, java_model.isotonic)
452451

453452

454453
class IsotonicRegression(object):
455454
"""
456455
Run IsotonicRegression algorithm to obtain isotonic regression model.
457456
458-
:param data: RDD of data points
457+
:param data: RDD of (label, feature, weight) tuples.
459458
:param isotonic: Whether this is isotonic or antitonic.
460459
"""
461460
@classmethod
462461
def train(cls, data, isotonic=True):
463462
"""Train an isotonic regression model on the given data."""
464463
boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel",
465464
data.map(_convert_to_vector), bool(isotonic))
466-
return IsotonicRegressionModel(np.array(boundaries), np.array(predictions), isotonic)
465+
return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic)
467466

468467

469468
def _test():

0 commit comments

Comments
 (0)