
Commit b0415e8

replace SQL JSON usage by json4s

1 parent fa6bdc6 · commit b0415e8
15 files changed: +92 −127 lines
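
Every hunk below follows the same pattern: the save path builds metadata as a json4s AST with the JsonDSL ~ operator and serializes it via compact(render(...)); the load path extracts typed fields with an implicit DefaultFormats in scope. A minimal round-trip sketch of that pattern (json4s only, no Spark; the class name and counts are illustrative, not taken from the diff):

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object MetadataRoundTrip extends App {
  implicit val formats: Formats = DefaultFormats

  // Save side: JsonDSL builds a JValue; compact(render(...)) yields one-line JSON.
  val metadata: JValue =
    ("class" -> "org.apache.spark.mllib.classification.SVMModel") ~
      ("version" -> "1.0") ~ ("numFeatures" -> 692) ~ ("numClasses" -> 2)
  val serialized = compact(render(metadata))

  // Load side: parse the text back and extract typed fields.
  val parsed = parse(serialized)
  val numFeatures = (parsed \ "numFeatures").extract[Int]
  val numClasses = (parsed \ "numClasses").extract[Int]
  println(s"numFeatures=$numFeatures, numClasses=$numClasses")  // 692, 2
}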

mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala

Lines changed: 5 additions & 11 deletions
@@ -17,12 +17,12 @@

 package org.apache.spark.mllib.classification

+import org.json4s.{DefaultFormats, JValue}
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}

 /**
  * :: Experimental ::
@@ -60,16 +60,10 @@ private[mllib] object ClassificationModel {

   /**
    * Helper method for loading GLM classification model metadata.
-   *
-   * @param modelClass  String name for model class (used for error messages)
    * @return (numFeatures, numClasses)
    */
-  def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
-    metadata.select("numFeatures", "numClasses").take(1)(0) match {
-      case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
-      case _ => throw new Exception(s"$modelClass unable to load" +
-        s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
-    }
+  def getNumFeaturesClasses(metadata: JValue): (Int, Int) = {
+    implicit val formats = DefaultFormats
+    ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int])
   }
-
 }
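
One side effect of the new helper worth noting: the deleted DataFrame code threw a hand-written Exception naming the model class and metadata path, whereas extract[Int] fails with json4s's generic MappingException when a field is absent or mistyped. A small sketch of both outcomes (json4s only; the JSON literals are invented for illustration):

import scala.util.Try

import org.json4s.{DefaultFormats, JValue}
import org.json4s.jackson.JsonMethods.parse

object ExtractBehavior extends App {
  implicit val formats = DefaultFormats

  // Mirrors the new ClassificationModel.getNumFeaturesClasses
  // (formats hoisted to the enclosing object).
  def getNumFeaturesClasses(metadata: JValue): (Int, Int) =
    ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int])

  val good = parse("""{"numFeatures": 692, "numClasses": 2}""")
  println(getNumFeaturesClasses(good))      // (692,2)

  val bad = parse("""{"numFeatures": 692}""")           // numClasses missing
  println(Try(getNumFeaturesClasses(bad)))  // Failure(org.json4s.MappingException: ...)
}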

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 1 addition & 2 deletions
@@ -173,8 +173,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
         // numFeatures, numClasses, weights are checked in model initialization
         val model =

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 9 additions & 9 deletions
@@ -18,15 +18,16 @@
 package org.apache.spark.mllib.classification

 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._

-import org.apache.spark.{SparkContext, SparkException, Logging}
+import org.apache.spark.{Logging, SparkContext, SparkException}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}

-
 /**
  * Model for Naive Bayes Classifiers.
  *
@@ -78,7 +79,7 @@ class NaiveBayesModel private[mllib] (

 object NaiveBayesModel extends Loader[NaiveBayesModel] {

-  import Loader._
+  import org.apache.spark.mllib.util.Loader._

   private object SaveLoadV1_0 {

@@ -95,10 +96,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
       import sqlContext.implicits._

       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
-          .toDataFrame("class", "version", "numFeatures", "numClasses")
-      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+          ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

       // Create Parquet data.
       val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
@@ -126,8 +127,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val model = SaveLoadV1_0.load(sc, path)
         assert(model.pi.size == numClasses,
           s"NaiveBayesModel.load expected $numClasses classes," +

mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala

Lines changed: 2 additions & 4 deletions
@@ -23,10 +23,9 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
+import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
 import org.apache.spark.rdd.RDD

-
 /**
  * Model for Support Vector Machines (SVMs).
  *
@@ -97,8 +96,7 @@ object SVMModel extends Loader[SVMModel] {
     val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
         val model = new SVMModel(data.weights, data.intercept)
         assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +

mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala

Lines changed: 9 additions & 8 deletions
@@ -17,10 +17,13 @@

 package org.apache.spark.mllib.classification.impl

+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.Loader
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLContext}

 /**
  * Helper class for import/export of GLM classification models.
@@ -52,16 +55,14 @@ private[classification] object GLMClassificationModel {
       import sqlContext.implicits._

       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
-          .toDataFrame("class", "version", "numFeatures", "numClasses")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+          ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

       // Create Parquet data.
       val data = Data(weights, intercept, threshold)
-      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
-      // TODO: repartition with 1 partition after SPARK-5532 gets fixed
-      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+      sc.parallelize(Seq(data), 1).saveAsParquetFile(Loader.dataPath(path))
     }

     /**

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 9 additions & 5 deletions
@@ -22,6 +22,9 @@ import java.lang.{Integer => JavaInteger}

 import org.apache.hadoop.fs.Path
 import org.jblas.DoubleMatrix
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._

 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
@@ -153,7 +156,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
   import org.apache.spark.mllib.util.Loader._

   override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
-    val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+    val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, formatVersion) match {
       case (className, "1.0") if className == classNameV1_0 =>
@@ -181,19 +184,20 @@
       val sc = model.userFeatures.sparkContext
       val sqlContext = new SQLContext(sc)
       import sqlContext.implicits._
-      val metadata = (thisClassName, thisFormatVersion, model.rank)
-      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
-      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
       model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
       model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
     }

     def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+      implicit val formats = DefaultFormats
       val sqlContext = new SQLContext(sc)
       val (className, formatVersion, metadata) = loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
-      val rank = metadata.select("rank").first().getInt(0)
+      val rank = (metadata \ "rank").extract[Int]
       val userFeatures = sqlContext.parquetFile(userPath(path))
         .map { case Row(id: Int, features: Seq[Double]) =>
           (id, features.toArray)
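
The load paths in this commit all use extract[...], which throws on bad metadata. json4s also offers extractOpt[...], a non-throwing variant returning Option; the commit does not use it, but the contrast is worth seeing (a sketch with illustrative values, not part of the diff):

import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

object RankExtract extends App {
  implicit val formats = DefaultFormats
  val metadata = parse("""{"class":"...","version":"1.0","rank":10}""")
  val rank = (metadata \ "rank").extract[Int]        // fail-fast, as in the commit
  val rankOpt = (metadata \ "rank").extractOpt[Int]  // Some(10); None if absent or mistyped
  println((rank, rankOpt))
}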

mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ object LassoModel extends Loader[LassoModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new LassoModel(data.weights, data.intercept)
       case _ => throw new Exception(

mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new LinearRegressionModel(data.weights, data.intercept)
       case _ => throw new Exception(

mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala

Lines changed: 5 additions & 11 deletions
@@ -17,12 +17,12 @@

 package org.apache.spark.mllib.regression

+import org.json4s.{DefaultFormats, JValue}
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}

 @Experimental
 trait RegressionModel extends Serializable {
@@ -55,16 +55,10 @@ private[mllib] object RegressionModel {

   /**
    * Helper method for loading GLM regression model metadata.
-   *
-   * @param modelClass  String name for model class (used for error messages)
    * @return numFeatures
    */
-  def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
-    metadata.select("numFeatures").take(1)(0) match {
-      case Row(nFeatures: Int) => nFeatures
-      case _ => throw new Exception(s"$modelClass unable to load" +
-        s" numFeatures from metadata: ${Loader.metadataPath(path)}")
-    }
+  def getNumFeatures(metadata: JValue): Int = {
+    implicit val formats = DefaultFormats
+    (metadata \ "numFeatures").extract[Int]
   }
-
 }

mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new RidgeRegressionModel(data.weights, data.intercept)
       case _ => throw new Exception(

mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala

Lines changed: 7 additions & 4 deletions
@@ -17,6 +17,9 @@

 package org.apache.spark.mllib.regression.impl

+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.Loader
@@ -48,10 +51,10 @@ private[regression] object GLMRegressionModel {
       import sqlContext.implicits._

       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1)
-          .toDataFrame("class", "version", "numFeatures")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+          ("numFeatures" -> weights.size)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

       // Create Parquet data.
       val data = Data(weights, intercept)

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 2 additions & 6 deletions
@@ -17,28 +17,24 @@

 package org.apache.spark.mllib.tree

-import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer

-
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.impl._
-import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.SparkContext._
-

 /**
  * :: Experimental ::

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 11 additions & 17 deletions
@@ -19,6 +19,10 @@ package org.apache.spark.mllib.tree.model

 import scala.collection.mutable

+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
@@ -184,10 +188,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
       import sqlContext.implicits._

       // Create JSON metadata.
-      val metadataRDD = sc.parallelize(
-        Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
-        .toDataFrame("class", "version", "algo", "numNodes")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+          ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

       // Create Parquet data.
       val nodes = model.topNode.subtreeIterator.toSeq
@@ -269,20 +273,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
   }

   override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+    implicit val formats = DefaultFormats
     val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
-    val (algo: String, numNodes: Int) = try {
-      val algo_numNodes = metadata.select("algo", "numNodes").collect()
-      assert(algo_numNodes.length == 1)
-      algo_numNodes(0) match {
-        case Row(a: String, n: Int) => (a, n)
-      }
-    } catch {
-      // Catch both Error and Exception since the checks above can throw either.
-      case e: Throwable =>
-        throw new Exception(
-          s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
-          + s" Error message: ${e.getMessage}")
-    }
+    val algo = (metadata \ "algo").extract[String]
+    val numNodes = (metadata \ "numNodes").extract[Int]
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
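
This last hunk replaces the collect/assert/match plus try/catch dance with two one-line extractions. For completeness: json4s can also bind several fields at once by extracting into a case class, which the commit does not do; TreeInfo below is a hypothetical name used only for this sketch (unmatched JSON fields such as "class" and "version" are ignored by default):

import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

// Hypothetical helper type, not part of the commit.
case class TreeInfo(algo: String, numNodes: Int)

object TreeMetadataExtract extends App {
  implicit val formats = DefaultFormats
  val metadata = parse(
    """{"class":"org.apache.spark.mllib.tree.model.DecisionTreeModel","version":"1.0","algo":"Classification","numNodes":5}""")
  val info = metadata.extract[TreeInfo]
  println(info)  // TreeInfo(Classification,5)
}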
