
Commit 6c42d61

imatiach-msft authored and cmonkey committed
[SPARK-16473][MLLIB] Fix BisectingKMeans algorithm failing in edge case where no children exist in updateAssignments

## What changes were proposed in this pull request?

Fix a bug in which BisectingKMeans fails because updateAssignments looks up the cluster center of a child index that was never created, producing:

    java.util.NoSuchElementException: key not found: 166
      at scala.collection.MapLike$class.default(MapLike.scala:228)
      at scala.collection.AbstractMap.default(Map.scala:58)
      at scala.collection.MapLike$class.apply(MapLike.scala:141)
      at scala.collection.AbstractMap.apply(Map.scala:58)
      at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply$mcDJ$sp(BisectingKMeans.scala:338)
      at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply(BisectingKMeans.scala:337)
      at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply(BisectingKMeans.scala:337)
      at scala.collection.TraversableOnce$$anonfun$minBy$1.apply(TraversableOnce.scala:231)
      at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:111)
      at scala.collection.immutable.List.foldLeft(List.scala:84)
      at scala.collection.LinearSeqOptimized$class.reduceLeft(LinearSeqOptimized.scala:125)
      at scala.collection.immutable.List.reduceLeft(List.scala:84)
      at scala.collection.TraversableOnce$class.minBy(TraversableOnce.scala:231)
      at scala.collection.AbstractTraversable.minBy(Traversable.scala:105)
      at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1.apply(BisectingKMeans.scala:337)
      at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1.apply(BisectingKMeans.scala:334)
      at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
      at scala.collection.Iterator$$anon$14.hasNext(Iterator.scala:389)

## How was this patch tested?

The failing dataset was run against the code change to verify that it works, and a regression test covering the edge case was added to BisectingKMeansSuite.

Author: Ilya Matiach <ilmat@microsoft.com>

Closes apache#16355 from imatiach-msft/ilmat/fix-kmeans.
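The root cause is a plain `Map#apply` on a key that may be absent. The following minimal, self-contained Scala sketch (the object name, indices, and scalar "centers" standing in for the real `VectorWithNorm` data are all hypothetical) reproduces the crash and shows the guard the patch introduces:

```scala
object Spark16473Sketch extends App {
  // Hypothetical centers keyed by child index: child 166 received no points
  // during the split, so it has no entry in the map.
  val newClusterCenters: Map[Long, Double] = Map(167L -> 0.42)
  val children = Seq(166L, 167L)
  val parentIndex = 83L

  // Before the patch, this line threw:
  //   java.util.NoSuchElementException: key not found: 166
  // val selected = children.minBy(child => newClusterCenters(child))

  // After the patch: consider only children that actually have a center,
  // and keep the point at its parent index when neither child survived.
  val existing = children.filter(newClusterCenters.contains)
  val assigned =
    if (existing.nonEmpty) existing.minBy(child => newClusterCenters(child))
    else parentIndex

  println(assigned) // 167
}
```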
1 parent 82ab33f commit 6c42d61

3 files changed: 44 additions, 7 deletions

mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala

Lines changed: 12 additions & 7 deletions
@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
     assignments.map { case (index, v) =>
       if (divisibleIndices.contains(index)) {
         val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val selected = children.minBy { child =>
-          KMeans.fastSquaredDistance(newClusterCenters(child), v)
+        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        if (newClusterChildren.nonEmpty) {
+          val selected = newClusterChildren.minBy { child =>
+            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+          }
+          (selected, v)
+        } else {
+          (index, v)
         }
-        (selected, v)
       } else {
         (index, v)
       }
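When a split leaves a child cluster with no assigned points, that child never receives an entry in `newClusterCenters`. The hunk above therefore restricts the `minBy` to children that actually have centers and, when neither child survived the split, keeps the point at its current (parent) index instead of looking up a center that was never computed.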
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
         internalIndex -= 1
         val leftIndex = leftChildIndex(rawIndex)
         val rightIndex = rightChildIndex(rawIndex)
-        val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+        val height = math.sqrt(indexes.map { childIndex =>
           KMeans.fastSquaredDistance(center, clusters(childIndex).center)
         }.max)
-        val left = buildSubTree(leftIndex)
-        val right = buildSubTree(rightIndex)
-        new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+        val children = indexes.map(buildSubTree(_)).toArray
+        new ClusteringTreeNode(index, size, center, cost, height, children)
       } else {
         val index = leafIndex
         leafIndex += 1
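This second hunk applies the same guard to tree construction: buildSubTree now recurses only into child indices present in `clusters`, and a node's height becomes the distance to its farthest existing child. One assumption worth noting: `indexes.map(...).max` still requires `indexes` to be non-empty (Scala's `max` throws `UnsupportedOperationException` on an empty collection), so every internal node is expected to retain at least one materialized child.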

mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala

Lines changed: 19 additions & 0 deletions
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _
 
+  @transient var sparseDataset: Dataset[_] = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
   }
 
   test("default parameters") {
@@ -51,6 +54,22 @@ class BisectingKMeansSuite
     assert(copiedModel.hasSummary)
   }
 
+  test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where " +
+    "one cluster is empty after split") {
+    val bkm = new BisectingKMeans()
+      .setK(k)
+      .setMinDivisibleClusterSize(4)
+      .setMaxIter(4)
+      .setSeed(123)
+
+    // Verify fit does not fail on very sparse data
+    val model = bkm.fit(sparseDataset)
+    val result = model.transform(sparseDataset)
+    val numClusters = result.select("prediction").distinct().collect().length
+    // Verify we hit the edge case
+    assert(numClusters < k && numClusters > 1)
+  }
+
   test("setter/getter") {
     val bkm = new BisectingKMeans()
       .setK(9)
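The assertion `numClusters < k && numClusters > 1` is what ties this test to the bug: on this very sparse input, at least one bisection produces an empty child, so training ends with fewer than `k` distinct predictions. Before the patch that situation crashed with the `NoSuchElementException` shown above; after it, fitting completes and simply yields fewer clusters.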

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 13 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }
 
+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
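With only 10 rows in 1000 dimensions, the generated data is extremely sparse relative to k = 5, which is evidently what drives some splits to leave a child empty. A standalone sketch of the generator's core idea, runnable in a Scala REPL with spark-mllib on the classpath (the `dim`, row count, and seed below mirror the test's values but are otherwise arbitrary):

```scala
import scala.util.Random

import org.apache.spark.ml.linalg.Vectors

val random = new Random(42)
val dim = 1000
// One non-zero count shared by all rows; the chosen indices still vary per row.
val nnz = random.nextInt(dim)
val vectors = (1 to 10).map { _ =>
  val indices = random.shuffle((0 until dim).toList).take(nnz).sorted.toArray
  val values = Array.fill(nnz)(random.nextDouble())
  Vectors.sparse(dim, indices, values)
}
```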
