[SPARK-16473][MLLIB] Fix BisectingKMeans Algorithm failing in edge case #16355

Closed · 7 commits
```diff
@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
     assignments.map { case (index, v) =>
       if (divisibleIndices.contains(index)) {
         val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val selected = children.minBy { child =>
-          KMeans.fastSquaredDistance(newClusterCenters(child), v)
+        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        if (newClusterChildren.nonEmpty) {
+          val selected = newClusterChildren.minBy { child =>
+            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+          }
+          (selected, v)
+        } else {
+          (index, v)
         }
-        (selected, v)
       } else {
         (index, v)
       }
```
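For context, here is a minimal, Spark-free sketch of the failure mode this hunk guards against: when a split leaves a child cluster empty, `newClusterCenters` has no entry for that child's index, so the old unconditional lookup inside `minBy` threw; the patched logic only considers children that actually have centers, and keeps the point in its parent when neither child survived. `EmptyChildSketch` and `squaredDistance` are hypothetical stand-ins (the latter for `KMeans.fastSquaredDistance`); the 2i / 2i + 1 child-index scheme mirrors the real implementation.

```scala
object EmptyChildSketch {
  // Children of node i are 2i and 2i + 1, as in BisectingKMeans.
  def leftChildIndex(index: Long): Long = 2 * index
  def rightChildIndex(index: Long): Long = 2 * index + 1

  // Stand-in for KMeans.fastSquaredDistance.
  def squaredDistance(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  def main(args: Array[String]): Unit = {
    val index = 3L
    val v = Array(1.0, 0.0)
    // Cluster 3 was split, but every point landed in the left child (6);
    // the right child (7) is empty and therefore has no center.
    val newClusterCenters = Map(6L -> Array(0.9, 0.1))

    val children = Seq(leftChildIndex(index), rightChildIndex(index)) // Seq(6, 7)
    // Old code: children.minBy(c => squaredDistance(newClusterCenters(c), v))
    // would throw NoSuchElementException for key 7.
    val newClusterChildren = children.filter(newClusterCenters.contains)
    val assignment =
      if (newClusterChildren.nonEmpty) {
        val selected = newClusterChildren.minBy { child =>
          squaredDistance(newClusterCenters(child), v)
        }
        (selected, v) // the point moves to the surviving child
      } else {
        (index, v) // no child survived: the point stays in its parent cluster
      }
    assert(assignment._1 == 6L)
  }
}
```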
```diff
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
         internalIndex -= 1
         val leftIndex = leftChildIndex(rawIndex)
         val rightIndex = rightChildIndex(rawIndex)
-        val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+        val height = math.sqrt(indexes.map { childIndex =>
           KMeans.fastSquaredDistance(center, clusters(childIndex).center)
         }.max)
-        val left = buildSubTree(leftIndex)
-        val right = buildSubTree(rightIndex)
-        new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+        val children = indexes.map(buildSubTree(_)).toArray
+        new ClusteringTreeNode(index, size, center, cost, height, children)
       } else {
         val index = leafIndex
         leafIndex += 1
```
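The same guard applies during tree construction: `buildSubTree` previously assumed both children exist. A REPL-style sketch of the changed node construction, with a hypothetical `Node` case class standing in for `ClusteringTreeNode`:

```scala
// Hypothetical stand-in for ClusteringTreeNode; fields trimmed to the sketch.
case class Node(index: Long, height: Double, children: Array[Node])

def distance(a: Array[Double], b: Array[Double]): Double =
  math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)

val center = Array(0.0, 0.0)
// Only the left child (6) of internal node 3 materialized as a cluster.
val clusters = Map(6L -> Array(3.0, 4.0))
val indexes = Seq(6L, 7L).filter(clusters.contains)

// Height is the distance to the farthest *existing* child center (5.0 here);
// the old Seq(leftIndex, rightIndex) version would have thrown on key 7.
val height = indexes.map(i => distance(center, clusters(i))).max
// The node carries a one-element children array instead of always two subtrees.
val node = Node(3L, height, indexes.map(i => Node(i, 0.0, Array.empty)).toArray)
assert(node.children.length == 1 && node.height == 5.0)
```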
```diff
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _

+  @transient var sparseDataset: Dataset[_] = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
   }

   test("default parameters") {
```
```diff
@@ -51,6 +54,22 @@ class BisectingKMeansSuite
     assert(copiedModel.hasSummary)
   }

+  test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where " +
+    "one cluster is empty after split") {
+    val bkm = new BisectingKMeans()
+      .setK(k)
+      .setMinDivisibleClusterSize(4)
+      .setMaxIter(4)
+      .setSeed(123)
+
+    // Verify fit does not fail on very sparse data
+    val model = bkm.fit(sparseDataset)
+    val result = model.transform(sparseDataset)
+    val numClusters = result.select("prediction").distinct().collect().length
+    // Verify we hit the edge case
+    assert(numClusters < k && numClusters > 1)
+  }
+
   test("setter/getter") {
     val bkm = new BisectingKMeans()
       .setK(9)
```

Review thread on the `// Verify fit does not fail on very sparse data` line:

Member: It's not clear to me that this unit test actually tests the issue fixed in this PR. Is there a good way to see why it would? If not, it would be great to write a tiny dataset by hand that triggers the failure.

Contributor (author): I added this check to verify it:

```scala
// Verify we hit the edge case
assert(numClusters < k && numClusters > 1)
```

The issue only occurs for very sparse data, but it occurs very consistently: almost all of the very sparse data I generate triggers the error.
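Following up on the reviewer's suggestion, here is a hypothetical hand-built tiny dataset of the shape the author describes: every row is a single spike in a high-dimensional space, so a bisecting split can easily leave one child cluster empty. This spark-shell sketch is illustrative only; that this exact data reproduces the pre-fix `NoSuchElementException` is an assumption, not something verified in the PR.

```scala
// Hypothetical spark-shell sketch, not part of the PR. Each row has exactly
// one nonzero in 1000 dimensions. `spark` is the shell's SparkSession.
import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors

val rows = Seq.tabulate(10)(i => Tuple1(Vectors.sparse(1000, Array(i * 97), Array(1.0))))
val df = spark.createDataFrame(rows).toDF("features")

val model = new BisectingKMeans()
  .setK(5)
  .setMinDivisibleClusterSize(4)
  .setMaxIter(4)
  .setSeed(123)
  .fit(df)

// With the fix in place, fit() completes and may return fewer than k clusters.
println(model.clusterCenters.length)
```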
```diff
@@ -17,6 +17,8 @@

 package org.apache.spark.ml.clustering

+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
```
```diff
@@ -160,6 +162,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }

+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
```
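One detail worth noting in the helper: `Vectors.sparse(size, indices, values)` expects the index array in ascending order, which is why the shuffled index sample is `.sorted` before each row's vector is built. A quick REPL check (values here are arbitrary):

```scala
import org.apache.spark.ml.linalg.Vectors

// Indices must be ascending and unique; the helper's `.sorted` restores this
// after `random.shuffle` scrambles the candidate indices.
val v = Vectors.sparse(1000, Array(3, 17, 512), Array(0.2, 0.5, 0.9))
println(v.numNonzeros) // 3
println(v(512))        // 0.9
```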