@@ -23,9 +23,15 @@ import java.util.Arrays.binarySearch
2323
2424import scala .collection .mutable .ArrayBuffer
2525
26+ import org .json4s .JsonDSL ._
27+ import org .json4s .jackson .JsonMethods ._
28+
2629import org .apache .spark .annotation .Experimental
2730import org .apache .spark .api .java .{JavaDoubleRDD , JavaRDD }
31+ import org .apache .spark .mllib .util .{Loader , Saveable }
2832import org .apache .spark .rdd .RDD
33+ import org .apache .spark .SparkContext
34+ import org .apache .spark .sql .{DataFrame , SQLContext }
2935
3036/**
3137 * :: Experimental ::
@@ -42,7 +48,7 @@ import org.apache.spark.rdd.RDD
4248class IsotonicRegressionModel (
4349 val boundaries : Array [Double ],
4450 val predictions : Array [Double ],
45- val isotonic : Boolean ) extends Serializable {
51+ val isotonic : Boolean ) extends Serializable with Saveable {
4652
4753 private val predictionOrd = if (isotonic) Ordering [Double ] else Ordering [Double ].reverse
4854
@@ -124,6 +130,71 @@ class IsotonicRegressionModel (
124130 predictions(foundIndex)
125131 }
126132 }
133+
134+ override def save (sc : SparkContext , path : String ): Unit = {
135+ val data = IsotonicRegressionModel .SaveLoadV1_0 .Data (boundaries, predictions, isotonic)
136+ IsotonicRegressionModel .SaveLoadV1_0 .save(sc, path, data)
137+ }
138+
139+ override protected def formatVersion : String = " 1.0"
140+ }
141+
142+ object IsotonicRegressionModel extends Loader [IsotonicRegressionModel ] {
143+
144+ import org .apache .spark .mllib .util .Loader ._
145+
146+ private object SaveLoadV1_0 {
147+
148+ def thisFormatVersion : String = " 1.0"
149+
150+ /** Hard-code class name string in case it changes in the future */
151+ def thisClassName : String = " org.apache.spark.mllib.regression.IsotonicRegressionModel"
152+
153+ /** Model data for model import/export */
154+ case class Data (boundaries : Array [Double ], predictions : Array [Double ], isotonic : Boolean )
155+
156+ def save (sc : SparkContext , path : String , data : Data ): Unit = {
157+ val sqlContext = new SQLContext (sc)
158+ import sqlContext .implicits ._
159+
160+ val metadata = compact(render(
161+ (" class" -> thisClassName) ~ (" version" -> thisFormatVersion)))
162+ sc.parallelize(Seq (metadata), 1 ).saveAsTextFile(metadataPath(path))
163+
164+ val dataRDD : DataFrame = sc.parallelize(Seq (data), 1 ).toDF()
165+ dataRDD.saveAsParquetFile(dataPath(path))
166+ }
167+
168+ def load (sc : SparkContext , path : String ): IsotonicRegressionModel = {
169+ val sqlContext = new SQLContext (sc)
170+ val dataRDD = sqlContext.parquetFile(dataPath(path))
171+
172+ checkSchema[Data ](dataRDD.schema)
173+ val dataArray = dataRDD.select(" boundaries" , " predictions" , " isotonic" ).take(1 )
174+ assert(dataArray.size == 1 ,
175+ s " Unable to load IsotonicRegressionModel data from: ${dataPath(path)}" )
176+ val data = dataArray(0 )
177+ val boundaries = data.getAs[Seq [Double ]](0 ).toArray
178+ val predictions = data.getAs[Seq [Double ]](1 ).toArray
179+ val isotonic = data.getAs[Boolean ](2 )
180+ new IsotonicRegressionModel (boundaries, predictions, isotonic)
181+ }
182+ }
183+
184+ override def load (sc : SparkContext , path : String ): IsotonicRegressionModel = {
185+ val (loadedClassName, version, metadata) = loadMetadata(sc, path)
186+ val classNameV1_0 = SaveLoadV1_0 .thisClassName
187+ (loadedClassName, version) match {
188+ case (className, " 1.0" ) if className == classNameV1_0 =>
189+ val model = SaveLoadV1_0 .load(sc, path)
190+ model
191+ case _ => throw new Exception (
192+ s " IsotonicRegressionModel.load did not recognize model with (className, format version): " +
193+ s " ( $loadedClassName, $version). Supported: \n " +
194+ s " ( $classNameV1_0, 1.0) "
195+ )
196+ }
197+ }
127198}
128199
129200/**
0 commit comments