Skip to content

Commit 2d4e00e

Browse files
yinxusen authored and mengxr committed
[SPARK-5986][MLLib] Add save/load for k-means
This PR adds save/load for K-means as described in SPARK-5986. The Python version will be added in another PR.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #4951 from yinxusen/SPARK-5986 and squashes the following commits:

6dd74a0 [Xusen Yin] rewrite some functions and classes
cd390fd [Xusen Yin] add indexed point
b144216 [Xusen Yin] remove invalid comments
dce7055 [Xusen Yin] add save/load for k-means for SPARK-5986
1 parent 2672374 commit 2d4e00e

File tree

2 files changed

+108
-4
lines changed

2 files changed

+108
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,22 @@
1717

1818
package org.apache.spark.mllib.clustering
1919

20+
import org.json4s._
21+
import org.json4s.JsonDSL._
22+
import org.json4s.jackson.JsonMethods._
23+
2024
import org.apache.spark.api.java.JavaRDD
21-
import org.apache.spark.rdd.RDD
22-
import org.apache.spark.SparkContext._
2325
import org.apache.spark.mllib.linalg.Vector
26+
import org.apache.spark.mllib.util.{Loader, Saveable}
27+
import org.apache.spark.rdd.RDD
28+
import org.apache.spark.SparkContext
29+
import org.apache.spark.sql.SQLContext
30+
import org.apache.spark.sql.Row
2431

2532
/**
2633
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
2734
*/
28-
class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
35+
class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
2936

3037
/** Total number of clusters. */
3138
def k: Int = clusterCenters.length
@@ -58,4 +65,59 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
5865

5966
/** Wraps each entry of `clusterCenters` in a `VectorWithNorm`. */
private def clusterCentersWithNorm: Iterable[VectorWithNorm] = {
  clusterCenters.map(center => new VectorWithNorm(center))
}
68+
69+
/**
 * Saves this model under `path` by delegating to `KMeansModel.SaveLoadV1_0.save`.
 *
 * @param sc   Spark context handed to the writer
 * @param path destination handed to the writer
 */
override def save(sc: SparkContext, path: String): Unit =
  KMeansModel.SaveLoadV1_0.save(sc, this, path)
72+
73+
/** Format version string reported by this model; the same "1.0" value used by `SaveLoadV1_0`. */
override protected def formatVersion: String = "1.0"
74+
}
75+
76+
object KMeansModel extends Loader[KMeansModel] {

  /**
   * Loads a model from `path` by delegating to the version 1.0 reader.
   *
   * @param sc   Spark context used to read the stored files
   * @param path location the model was saved to
   * @return the restored model
   */
  override def load(sc: SparkContext, path: String): KMeansModel = {
    KMeansModel.SaveLoadV1_0.load(sc, path)
  }

  /** One stored cluster center: its index within `clusterCenters` and its coordinates. */
  private case class Cluster(id: Int, point: Vector)

  private object Cluster {
    /** Builds a `Cluster` from a row, reading column 0 as the id and column 1 as the point. */
    def apply(r: Row): Cluster = {
      Cluster(r.getInt(0), r.getAs[Vector](1))
    }
  }

  /** Reader and writer for format version "1.0". */
  private[clustering]
  object SaveLoadV1_0 {

    private val thisFormatVersion = "1.0"

    private[clustering]
    val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"

    /**
     * Writes `model` under `path`.
     *
     * A JSON metadata record (class name, format version, k) is saved as a text file at
     * `Loader.metadataPath(path)`, and the cluster centers, each paired with its index,
     * are saved as a Parquet file at `Loader.dataPath(path)`.
     */
    def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._
      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
      // Keep each center's index so that load can restore the original ordering.
      val centersDF = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
        Cluster(id, point)
      }.toDF()
      centersDF.saveAsParquetFile(Loader.dataPath(path))
    }

    /**
     * Reads a model written by `save`.
     *
     * Fails with an AssertionError (now carrying a descriptive message) when the stored
     * class name or format version differs from this reader's, or when the number of
     * stored centers differs from the `k` recorded in the metadata.
     */
    def load(sc: SparkContext, path: String): KMeansModel = {
      implicit val formats = DefaultFormats
      val sqlContext = new SQLContext(sc)
      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
      assert(className == thisClassName,
        s"Expected model class $thisClassName but metadata at $path declares $className")
      assert(formatVersion == thisFormatVersion,
        s"Expected format version $thisFormatVersion but metadata at $path " +
          s"declares $formatVersion")
      val k = (metadata \ "k").extract[Int]
      val centroids = sqlContext.parquetFile(Loader.dataPath(path))
      Loader.checkSchema[Cluster](centroids.schema)
      val localCentroids = centroids.map(Cluster.apply).collect()
      assert(k == localCentroids.length,
        s"Metadata at $path declares k = $k but ${localCentroids.length} " +
          "cluster centers were loaded")
      // Rows come back in no guaranteed order; sort by the stored index before unwrapping.
      new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
    }
  }
}

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ import scala.util.Random
2121

2222
import org.scalatest.FunSuite
2323

24-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
24+
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
2525
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
2626
import org.apache.spark.mllib.util.TestingUtils._
27+
import org.apache.spark.util.Utils
2728

2829
class KMeansSuite extends FunSuite with MLlibTestSparkContext {
2930

@@ -257,6 +258,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
257258
assert(predicts(0) != predicts(3))
258259
}
259260
}
261+
262+
test("model save/load") {
  val tempDir = Utils.createTempDir()
  val path = tempDir.toURI.toString

  // Round-trip one model with sparse centers and one with dense centers.
  for (isSparse <- Seq(true, false)) {
    val original = KMeansSuite.createModel(10, 3, isSparse)
    try {
      original.save(sc, path)
      val restored = KMeansModel.load(sc, path)
      KMeansSuite.checkEqual(original, restored)
    } finally {
      // Remove the temp directory after each round trip, whether or not it passed.
      Utils.deleteRecursively(tempDir)
    }
  }
}
278+
}
279+
280+
object KMeansSuite extends FunSuite {

  /**
   * Builds a model whose `k` centers are all the same zero vector of dimension `dim`.
   *
   * @param isSparse when true the center is created with `Vectors.sparse` (no stored
   *                 entries); otherwise with `Vectors.dense` filled with 0.0
   */
  def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
    val zeroCenter =
      if (isSparse) {
        Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
      } else {
        Vectors.dense(Array.fill[Double](dim)(0.0))
      }
    new KMeansModel(Array.fill[Vector](k)(zeroCenter))
  }

  /**
   * Asserts that two models have the same `k` and pairwise-equal centers.
   *
   * Each pair of centers must be both sparse or both dense; any other combination
   * raises an AssertionError.
   */
  def checkEqual(a: KMeansModel, b: KMeansModel): Unit = {
    assert(a.k === b.k)
    val centerPairs = a.clusterCenters.zip(b.clusterCenters)
    for (pair <- centerPairs) {
      pair match {
        case (expected: SparseVector, actual: SparseVector) =>
          assert(expected === actual)
        case (expected: DenseVector, actual: DenseVector) =>
          assert(expected === actual)
        case _ =>
          throw new AssertionError(
            "checkEqual failed since the two clusters were not identical.\n")
      }
    }
  }
}
261303

262304
class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {

0 commit comments

Comments
 (0)