Skip to content

Commit 1f2f723

Browse files
yanboliangmengxr
authored andcommitted
[SPARK-5990] [MLLIB] Model import/export for IsotonicRegression
Model import/export for IsotonicRegression Author: Yanbo Liang <ybliang8@gmail.com> Closes apache#5270 from yanboliang/spark-5990 and squashes the following commits: 872028d [Yanbo Liang] fix code style f80ec1b [Yanbo Liang] address comments 49600cc [Yanbo Liang] address comments 429ff7d [Yanbo Liang] store each interval as a record 2b2f5a1 [Yanbo Liang] Model import/export for IsotonicRegression
1 parent ab9128f commit 1f2f723

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,16 @@ import java.util.Arrays.binarySearch
2323

2424
import scala.collection.mutable.ArrayBuffer
2525

26+
import org.json4s._
27+
import org.json4s.JsonDSL._
28+
import org.json4s.jackson.JsonMethods._
29+
2630
import org.apache.spark.annotation.Experimental
2731
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
32+
import org.apache.spark.mllib.util.{Loader, Saveable}
2833
import org.apache.spark.rdd.RDD
34+
import org.apache.spark.SparkContext
35+
import org.apache.spark.sql.{DataFrame, SQLContext}
2936

3037
/**
3138
* :: Experimental ::
@@ -42,7 +49,7 @@ import org.apache.spark.rdd.RDD
4249
class IsotonicRegressionModel (
4350
val boundaries: Array[Double],
4451
val predictions: Array[Double],
45-
val isotonic: Boolean) extends Serializable {
52+
val isotonic: Boolean) extends Serializable with Saveable {
4653

4754
private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse
4855

@@ -124,6 +131,75 @@ class IsotonicRegressionModel (
124131
predictions(foundIndex)
125132
}
126133
}
134+
135+
override def save(sc: SparkContext, path: String): Unit = {
136+
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
137+
}
138+
139+
override protected def formatVersion: String = "1.0"
140+
}
141+
142+
object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
143+
144+
import org.apache.spark.mllib.util.Loader._
145+
146+
private object SaveLoadV1_0 {
147+
148+
def thisFormatVersion: String = "1.0"
149+
150+
/** Hard-code class name string in case it changes in the future */
151+
def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"
152+
153+
/** Model data for model import/export */
154+
case class Data(boundary: Double, prediction: Double)
155+
156+
def save(
157+
sc: SparkContext,
158+
path: String,
159+
boundaries: Array[Double],
160+
predictions: Array[Double],
161+
isotonic: Boolean): Unit = {
162+
val sqlContext = new SQLContext(sc)
163+
164+
val metadata = compact(render(
165+
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
166+
("isotonic" -> isotonic)))
167+
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
168+
169+
sqlContext.createDataFrame(
170+
boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) }
171+
).saveAsParquetFile(dataPath(path))
172+
}
173+
174+
def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
175+
val sqlContext = new SQLContext(sc)
176+
val dataRDD = sqlContext.parquetFile(dataPath(path))
177+
178+
checkSchema[Data](dataRDD.schema)
179+
val dataArray = dataRDD.select("boundary", "prediction").collect()
180+
val (boundaries, predictions) = dataArray.map { x =>
181+
(x.getDouble(0), x.getDouble(1))
182+
}.toList.sortBy(_._1).unzip
183+
(boundaries.toArray, predictions.toArray)
184+
}
185+
}
186+
187+
override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
188+
implicit val formats = DefaultFormats
189+
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
190+
val isotonic = (metadata \ "isotonic").extract[Boolean]
191+
val classNameV1_0 = SaveLoadV1_0.thisClassName
192+
(loadedClassName, version) match {
193+
case (className, "1.0") if className == classNameV1_0 =>
194+
val (boundaries, predictions) = SaveLoadV1_0.load(sc, path)
195+
new IsotonicRegressionModel(boundaries, predictions, isotonic)
196+
case _ => throw new Exception(
197+
s"IsotonicRegressionModel.load did not recognize model with (className, format version):" +
198+
s"($loadedClassName, $version). Supported:\n" +
199+
s" ($classNameV1_0, 1.0)"
200+
)
201+
}
202+
}
127203
}
128204

129205
/**

mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.scalatest.{Matchers, FunSuite}
2121

2222
import org.apache.spark.mllib.util.MLlibTestSparkContext
2323
import org.apache.spark.mllib.util.TestingUtils._
24+
import org.apache.spark.util.Utils
2425

2526
class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
2627

@@ -73,6 +74,26 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
7374
assert(model.isotonic)
7475
}
7576

77+
test("model save/load") {
78+
val boundaries = Array(0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)
79+
val predictions = Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)
80+
val model = new IsotonicRegressionModel(boundaries, predictions, true)
81+
82+
val tempDir = Utils.createTempDir()
83+
val path = tempDir.toURI.toString
84+
85+
// Save model, load it back, and compare.
86+
try {
87+
model.save(sc, path)
88+
val sameModel = IsotonicRegressionModel.load(sc, path)
89+
assert(model.boundaries === sameModel.boundaries)
90+
assert(model.predictions === sameModel.predictions)
91+
assert(model.isotonic === model.isotonic)
92+
} finally {
93+
Utils.deleteRecursively(tempDir)
94+
}
95+
}
96+
7697
test("isotonic regression with size 0") {
7798
val model = runIsotonicRegression(Seq(), true)
7899

0 commit comments

Comments
 (0)