Skip to content

Commit 2b2f5a1

Browse files
committed
Model import/export for IsotonicRegression
1 parent 19d4c39 commit 2b2f5a1

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@ import java.util.Arrays.binarySearch
2323

2424
import scala.collection.mutable.ArrayBuffer
2525

26+
import org.json4s.JsonDSL._
27+
import org.json4s.jackson.JsonMethods._
28+
2629
import org.apache.spark.annotation.Experimental
2730
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
31+
import org.apache.spark.mllib.util.{Loader, Saveable}
2832
import org.apache.spark.rdd.RDD
33+
import org.apache.spark.SparkContext
34+
import org.apache.spark.sql.{DataFrame, SQLContext}
2935

3036
/**
3137
* :: Experimental ::
@@ -42,7 +48,7 @@ import org.apache.spark.rdd.RDD
4248
class IsotonicRegressionModel (
4349
val boundaries: Array[Double],
4450
val predictions: Array[Double],
45-
val isotonic: Boolean) extends Serializable {
51+
val isotonic: Boolean) extends Serializable with Saveable {
4652

4753
private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse
4854

@@ -124,6 +130,71 @@ class IsotonicRegressionModel (
124130
predictions(foundIndex)
125131
}
126132
}
133+
134+
/**
 * Persists this model under `path`.
 *
 * Packs `boundaries`, `predictions` and `isotonic` into a version 1.0 `Data` record and
 * hands it to the companion object's `SaveLoadV1_0.save`, which does the actual writing.
 *
 * @param sc   Spark context used to write the files
 * @param path location the model is written to
 */
override def save(sc: SparkContext, path: String): Unit = {
  import IsotonicRegressionModel.SaveLoadV1_0
  SaveLoadV1_0.save(sc, path, SaveLoadV1_0.Data(boundaries, predictions, isotonic))
}
138+
139+
/**
 * Format version reported by this model: "1.0", the same value that `SaveLoadV1_0` writes
 * into the "version" metadata field and that `IsotonicRegressionModel.load` dispatches on.
 */
override protected def formatVersion: String = "1.0"
140+
}
141+
142+
object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {

  import org.apache.spark.mllib.util.Loader._

  /** Reader and writer for the version 1.0 on-disk layout of an [[IsotonicRegressionModel]]. */
  private object SaveLoadV1_0 {

    /** Version string written into the metadata by `save`. */
    def thisFormatVersion: String = "1.0"

    /** Class name written into the metadata; hard-coded so a later rename cannot alter it. */
    def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

    /** Single-row record holding every field of the model for import/export. */
    case class Data(boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean)

    /**
     * Writes the JSON metadata (class name and format version) as a text file, then the
     * model fields as a one-row Parquet file.
     *
     * @param sc   Spark context used to build the single-partition RDDs
     * @param path root location; the metadata and data locations are derived from it via
     *             `metadataPath` and `dataPath`
     * @param data model fields to persist
     */
    def save(sc: SparkContext, path: String, data: Data): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      // Metadata first: {"class": ..., "version": ...} rendered as compact JSON.
      val metadataJson = ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
      val metadataText = compact(render(metadataJson))
      sc.parallelize(Seq(metadataText), 1).saveAsTextFile(metadataPath(path))

      // Then the model itself, as a one-row DataFrame.
      val modelFrame: DataFrame = sc.parallelize(Seq(data), 1).toDF()
      modelFrame.saveAsParquetFile(dataPath(path))
    }

    /**
     * Reads back a model written by `save`.
     *
     * Checks the Parquet schema against `Data`, takes one row and asserts that a row was
     * actually returned before rebuilding the model from its three columns.
     *
     * @param sc   Spark context used to create the SQL context
     * @param path root location that was passed to `save`
     * @return the reconstructed model
     */
    def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
      val sqlContext = new SQLContext(sc)
      val stored = sqlContext.parquetFile(dataPath(path))

      checkSchema[Data](stored.schema)
      val rows = stored.select("boundaries", "predictions", "isotonic").take(1)
      assert(rows.size == 1,
        s"Unable to load IsotonicRegressionModel data from: ${dataPath(path)}")

      // Column positions follow the order given to `select` above.
      val row = rows(0)
      new IsotonicRegressionModel(
        row.getAs[Seq[Double]](0).toArray,
        row.getAs[Seq[Double]](1).toArray,
        row.getAs[Boolean](2))
    }
  }

  /**
   * Loads a model from `path`, dispatching on the (class name, format version) pair found
   * in the stored metadata.
   *
   * @param sc   Spark context used for reading
   * @param path location the model was saved to
   * @return the loaded model
   * @throws Exception if the stored class name or version is not the supported 1.0 pair
   */
  override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
    val (loadedClassName, version, _) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName

    if (version == "1.0" && loadedClassName == classNameV1_0) {
      SaveLoadV1_0.load(sc, path)
    } else {
      throw new Exception(
        s"IsotonicRegressionModel.load did not recognize model with (className, format version):" +
        s"($loadedClassName, $version). Supported:\n" +
        s" ($classNameV1_0, 1.0)"
      )
    }
  }
}
128199

129200
/**

mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.scalatest.{Matchers, FunSuite}
2121

2222
import org.apache.spark.mllib.util.MLlibTestSparkContext
2323
import org.apache.spark.mllib.util.TestingUtils._
24+
import org.apache.spark.util.Utils
2425

2526
class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
2627

@@ -73,6 +74,24 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
7374
assert(model.isotonic)
7475
}
7576

77+
test("model save/load") {
  val model = runIsotonicRegression(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18), true)

  val tempDir = Utils.createTempDir()
  val path = tempDir.toURI.toString

  // Save model, load it back, and compare every field against the reloaded copy.
  try {
    model.save(sc, path)
    val sameModel = IsotonicRegressionModel.load(sc, path)
    assert(model.boundaries === sameModel.boundaries)
    assert(model.predictions === sameModel.predictions)
    // Must compare with `sameModel`: checking `model.isotonic` against itself can never fail,
    // so a broken round trip of the flag would go unnoticed.
    assert(model.isotonic === sameModel.isotonic)
  } finally {
    // Clean up the temporary directory even when an assertion above fails.
    Utils.deleteRecursively(tempDir)
  }
}
94+
7695
test("isotonic regression with size 0") {
7796
val model = runIsotonicRegression(Seq(), true)
7897

0 commit comments

Comments
 (0)