
Commit c0c902a

mgaido91 authored and srowen committed
[SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance
## What changes were proposed in this pull request?

In apache#19340, some comments suggested using spherical KMeans when the cosine distance measure is specified, as MATLAB does, instead of the implementation based on the behavior of other tools/libraries like RapidMiner, NLTK and ELKI, i.e. computing the centroids as the mean of all the points in the clusters. This PR introduces the approach used in spherical KMeans. This behavior has the nice property of minimizing the within-cluster cosine distance.

## How was this patch tested?

Existing/improved UTs.

Author: Marco Gaido <marcogaido91@gmail.com>

Closes apache#20518 from mgaido91/SPARK-22119_followup.
1 parent 4bbd744 commit c0c902a
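To make the change concrete: under the old behavior the centroid of a cluster was the plain arithmetic mean of its points, while spherical KMeans normalizes each point to unit length before summing, then renormalizes the averaged result. A minimal, self-contained sketch of the two centroid definitions (plain Scala; the vectors and helper functions are illustrative only, not part of the patch):

```scala
object SphericalCentroidSketch {
  // L2 norm of a vector.
  def norm(v: Array[Double]): Double = math.sqrt(v.map(x => x * x).sum)

  // Element-wise sum of two vectors.
  def add(a: Array[Double], b: Array[Double]): Array[Double] =
    a.zip(b).map { case (x, y) => x + y }

  // Scale a vector by a constant.
  def scale(v: Array[Double], c: Double): Array[Double] = v.map(_ * c)

  def main(args: Array[String]): Unit = {
    // Two points with similar directions but very different magnitudes.
    val p1 = Array(1.0, 1.0)
    val p2 = Array(100.0, 90.0)

    // Mean-based centroid (old behavior): dominated by the large point.
    val mean = scale(add(p1, p2), 1.0 / 2)

    // Spherical centroid (this patch): normalize each point first,
    // average, then renormalize the result to unit length.
    val sumOfUnit = add(scale(p1, 1.0 / norm(p1)), scale(p2, 1.0 / norm(p2)))
    val avg = scale(sumOfUnit, 1.0 / 2)
    val spherical = scale(avg, 1.0 / norm(avg))

    println(s"mean centroid:      ${mean.mkString(", ")}")
    println(s"spherical centroid: ${spherical.mkString(", ")} (unit norm)")
  }
}
```

The mean centroid is dominated by the large-magnitude point, whereas the spherical centroid depends only on the points' directions, which is the property that ties it to the within-cluster cosine distance.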

File tree

2 files changed (+62, -7 lines)


mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 48 additions & 6 deletions
```diff
@@ -310,19 +310,17 @@ class KMeans private (
         points.foreach { point =>
           val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
           costAccum.add(cost)
-          val sum = sums(bestCenter)
-          axpy(1.0, point.vector, sum)
+          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
           counts(bestCenter) += 1
         }

         counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
       }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
         axpy(1.0, sum2, sum1)
         (sum1, count1 + count2)
-      }.mapValues { case (sum, count) =>
-        scal(1.0 / count, sum)
-        new VectorWithNorm(sum)
-      }.collectAsMap()
+      }.collectAsMap().mapValues { case (sum, count) =>
+        distanceMeasureInstance.centroid(sum, count)
+      }

       bcCenters.destroy(blocking = false)

@@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable {
       v1: VectorWithNorm,
       v2: VectorWithNorm): Double

+  /**
+   * Updates the value of `sum` adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    new VectorWithNorm(sum)
+  }
 }

 @Since("2.4.0")
@@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
    * @return the cosine distance between the two input vectors
    */
   override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
     1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
   }
+
+  /**
+   * Updates the value of `sum` adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0 / point.norm, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  override def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    val norm = Vectors.norm(sum, 2)
+    scal(1.0 / norm, sum)
+    new VectorWithNorm(sum, 1)
+  }
 }
```
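For context on how this code path is reached from the public API, the distance measure is selected on the estimator. A minimal usage sketch, assuming Spark 2.4+ with an active `SparkSession` named `spark` (the input vectors are made up for illustration, echoing the directions used in the test suite):

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors

// Made-up data: two directions, with magnitudes that would dominate
// a Euclidean clustering but are irrelevant to a cosine-based one.
val df = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(1.0, 1.0)),
  Tuple1(Vectors.dense(100.0, 90.0)),
  Tuple1(Vectors.dense(-1.0, 1.0)),
  Tuple1(Vectors.dense(-100.0, 90.0))
)).toDF("features")

val kmeans = new KMeans()
  .setK(2)
  .setDistanceMeasure("cosine") // routes to CosineDistanceMeasure
  .setSeed(1L)

val model = kmeans.fit(df)

// With this patch, every returned cluster center has unit L2 norm.
model.clusterCenters.foreach { c =>
  println(s"$c -> norm = ${Vectors.norm(c, 2)}")
}
```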

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 14 additions & 1 deletion
```diff
@@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering

 import scala.util.Random

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
       predictionsMap(Vectors.dense(-100.0, 90.0)))

+    model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
+  }
+
+  test("KMeans with cosine distance is not supported for 0-length vectors") {
+    val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2)
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+      Vectors.dense(0.0, 0.0),
+      Vectors.dense(10.0, 10.0),
+      Vectors.dense(1.0, 0.5)
+    )).map(v => TestRow(v)))
+    val e = intercept[SparkException](model.fit(df))
+    assert(e.getCause.isInstanceOf[AssertionError])
+    assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
   }

   test("read/write") {
```
