@@ -23,9 +23,15 @@ import java.util.Arrays.binarySearch
23
23
24
24
import scala .collection .mutable .ArrayBuffer
25
25
26
+ import org .json4s .JsonDSL ._
27
+ import org .json4s .jackson .JsonMethods ._
28
+
26
29
import org .apache .spark .annotation .Experimental
27
30
import org .apache .spark .api .java .{JavaDoubleRDD , JavaRDD }
31
+ import org .apache .spark .mllib .util .{Loader , Saveable }
28
32
import org .apache .spark .rdd .RDD
33
+ import org .apache .spark .SparkContext
34
+ import org .apache .spark .sql .{DataFrame , SQLContext }
29
35
30
36
/**
31
37
* :: Experimental ::
@@ -42,7 +48,7 @@ import org.apache.spark.rdd.RDD
42
48
class IsotonicRegressionModel (
43
49
val boundaries : Array [Double ],
44
50
val predictions : Array [Double ],
45
- val isotonic : Boolean ) extends Serializable {
51
+ val isotonic : Boolean ) extends Serializable with Saveable {
46
52
47
53
private val predictionOrd = if (isotonic) Ordering [Double ] else Ordering [Double ].reverse
48
54
@@ -124,6 +130,71 @@ class IsotonicRegressionModel (
124
130
predictions(foundIndex)
125
131
}
126
132
}
133
+
134
/**
 * Saves this model under `path`.
 *
 * Packs the boundaries, predictions and isotonic flag into the companion's
 * version 1.0 `Data` record and delegates the writing to `SaveLoadV1_0.save`.
 *
 * @param sc   Spark context used to write the model
 * @param path location the model is written under
 */
override def save(sc: SparkContext, path: String): Unit = {
  import IsotonicRegressionModel.SaveLoadV1_0
  SaveLoadV1_0.save(sc, path, SaveLoadV1_0.Data(boundaries, predictions, isotonic))
}
138
+
139
/** Version tag of the persistence format used by this model. */
override protected def formatVersion: String = "1.0"
140
+ }
141
+
142
/**
 * Companion object providing model loading for [[IsotonicRegressionModel]].
 */
object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {

  // NOTE(review): metadataPath, dataPath, checkSchema and loadMetadata are not defined in this
  // file; they are presumably brought into scope by this import -- confirm.
  import org.apache.spark.mllib.util.Loader._

  /** Reader/writer for version 1.0 of the persisted model format. */
  private object SaveLoadV1_0 {

    /** Format version written to the metadata by `save`. */
    def thisFormatVersion: String = "1.0"

    /** Hard-code class name string in case it changes in the future */
    def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

    /** Model data for model import/export */
    case class Data(boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean)

    /**
     * Writes the model to `path` in two parts: a JSON metadata record (class name and
     * format version) saved as a text file, and a single-row Parquet file holding `data`.
     *
     * @param sc   Spark context used to create the RDDs that are written out
     * @param path location the model is written under
     * @param data model contents to persist
     */
    def save(sc: SparkContext, path: String, data: Data): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
      dataRDD.saveAsParquetFile(dataPath(path))
    }

    /**
     * Reads the Parquet data written by `save` and rebuilds the model from its first row.
     *
     * @param sc   Spark context used to read the data
     * @param path location the model was written under
     * @return the reconstructed model
     */
    def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
      val sqlContext = new SQLContext(sc)
      val dataRDD = sqlContext.parquetFile(dataPath(path))

      checkSchema[Data](dataRDD.schema)
      // Column names must match the field names of `Data`, in this order, because the
      // values are read back by position below.
      val dataArray = dataRDD.select("boundaries", "predictions", "isotonic").take(1)
      assert(dataArray.size == 1,
        s"Unable to load IsotonicRegressionModel data from: ${dataPath(path)}")
      val data = dataArray(0)
      val boundaries = data.getAs[Seq[Double]](0).toArray
      val predictions = data.getAs[Seq[Double]](1).toArray
      val isotonic = data.getAs[Boolean](2)
      new IsotonicRegressionModel(boundaries, predictions, isotonic)
    }
  }

  /**
   * Loads a model from `path`, dispatching on the class name and format version found in
   * the saved metadata.
   *
   * @param sc   Spark context used to read the model
   * @param path location the model was written under
   * @return the loaded model
   * @throws Exception if the saved (class name, format version) pair is not supported
   */
  override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
    // The third element of the loaded metadata is not needed for format 1.0.
    val (loadedClassName, version, _) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    val versionV1_0 = SaveLoadV1_0.thisFormatVersion
    (loadedClassName, version) match {
      // Backticks compare against the values above instead of binding new names.
      case (`classNameV1_0`, `versionV1_0`) =>
        SaveLoadV1_0.load(sc, path)
      case _ => throw new Exception(
        "IsotonicRegressionModel.load did not recognize model with " +
        "(className, format version): " +
        s"($loadedClassName, $version). Supported:\n" +
        s"($classNameV1_0, $versionV1_0)")
    }
  }
}
128
199
129
200
/**
0 commit comments