@@ -23,9 +23,16 @@ import java.util.Arrays.binarySearch
2323
2424import scala .collection .mutable .ArrayBuffer
2525
26+ import org .json4s ._
27+ import org .json4s .JsonDSL ._
28+ import org .json4s .jackson .JsonMethods ._
29+
2630import org .apache .spark .annotation .Experimental
2731import org .apache .spark .api .java .{JavaDoubleRDD , JavaRDD }
32+ import org .apache .spark .mllib .util .{Loader , Saveable }
2833import org .apache .spark .rdd .RDD
34+ import org .apache .spark .SparkContext
35+ import org .apache .spark .sql .{DataFrame , SQLContext }
2936
3037/**
3138 * :: Experimental ::
@@ -42,7 +49,7 @@ import org.apache.spark.rdd.RDD
4249class IsotonicRegressionModel (
4350 val boundaries : Array [Double ],
4451 val predictions : Array [Double ],
45- val isotonic : Boolean ) extends Serializable {
52+ val isotonic : Boolean ) extends Serializable with Saveable {
4653
4754 private val predictionOrd = if (isotonic) Ordering [Double ] else Ordering [Double ].reverse
4855
@@ -124,6 +131,75 @@ class IsotonicRegressionModel (
124131 predictions(foundIndex)
125132 }
126133 }
134+
/**
 * Saves this model under `path` by delegating to the format-version-1.0 writer in the
 * companion object, handing it this model's boundaries, predictions and isotonic flag.
 *
 * @param sc   Spark context passed through to the writer.
 * @param path Location passed through to the writer.
 */
override def save(sc: SparkContext, path: String): Unit = {
  IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
}
138+
/**
 * Format version string of this model.
 * NOTE(review): expected to stay in sync with the version written and matched by the
 * companion object's save/load code ("1.0") - confirm when adding a new format.
 */
override protected def formatVersion: String = "1.0"
140+ }
141+
/**
 * Companion object of [[IsotonicRegressionModel]] implementing model import.
 *
 * A saved model consists of a single-line JSON metadata text file (class name, format
 * version and the `isotonic` flag) plus a Parquet file holding one
 * (boundary, prediction) row per model point.
 */
object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {

  import org.apache.spark.mllib.util.Loader._

  /** Reader/writer for format version 1.0. */
  private object SaveLoadV1_0 {

    def thisFormatVersion: String = "1.0"

    /** Hard-code class name string in case it changes in the future */
    def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

    /** Model data for model import/export */
    case class Data(boundary: Double, prediction: Double)

    /**
     * Writes the metadata and the model data below `path`.
     *
     * @param sc          Spark context used to write both outputs.
     * @param path        Base location; the concrete output locations are derived from it
     *                    via `metadataPath` and `dataPath`.
     * @param boundaries  Boundary values, paired by position with `predictions`.
     * @param predictions Prediction values, paired by position with `boundaries`.
     * @param isotonic    Flag stored in the metadata so that `load` can restore it.
     */
    def save(
        sc: SparkContext,
        path: String,
        boundaries: Array[Double],
        predictions: Array[Double],
        isotonic: Boolean): Unit = {
      val sqlContext = new SQLContext(sc)

      // NOTE(review): "class" and "version" are presumably the keys that `loadMetadata`
      // reads back - confirm against Loader. "isotonic" is read by `load` below.
      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
          ("isotonic" -> isotonic)))
      // A single partition keeps the metadata in one output part.
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

      sqlContext.createDataFrame(
        boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) }
      ).saveAsParquetFile(dataPath(path))
    }

    /**
     * Reads the model data written by `save`.
     *
     * @return (boundaries, predictions), ordered by ascending boundary.
     */
    def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
      val sqlContext = new SQLContext(sc)
      val dataRDD = sqlContext.parquetFile(dataPath(path))

      checkSchema[Data](dataRDD.schema)
      // Column names must match the field names of `Data` exactly.
      val dataArray = dataRDD.select("boundary", "prediction").collect()
      // Rows are re-sorted by boundary after collecting, so the result does not depend on
      // the order in which they come back.
      val (boundaries, predictions) = dataArray.map { x =>
        (x.getDouble(0), x.getDouble(1))
      }.toList.sortBy(_._1).unzip
      (boundaries.toArray, predictions.toArray)
    }
  }

  /**
   * Loads a model previously written in format version 1.0.
   *
   * @param sc   Spark context used for reading.
   * @param path Base location that was passed to `save`.
   * @return the restored model.
   * @throws Exception if the stored (class name, format version) pair is not supported.
   */
  override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
    implicit val formats: Formats = DefaultFormats
    val (loadedClassName, version, metadata) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        // Read the flag only after class and version are validated, so metadata of an
        // unrelated model ends up in the descriptive error below instead of failing on
        // a missing "isotonic" field.
        val isotonic = (metadata \ "isotonic").extract[Boolean]
        val (boundaries, predictions) = SaveLoadV1_0.load(sc, path)
        new IsotonicRegressionModel(boundaries, predictions, isotonic)
      case _ => throw new Exception(
        "IsotonicRegressionModel.load did not recognize model with " +
          s"(className, format version): ($loadedClassName, $version). Supported:\n" +
          s"  ($classNameV1_0, 1.0)"
      )
    }
  }
}
128204
129205/**
0 commit comments