Skip to content

Commit 5eb2835

Browse files
committed
backport fix
1 parent e43f161 commit 5eb2835

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

R/pkg/R/mllib.R

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -671,10 +671,13 @@ setMethod("fitted", signature(object = "KMeansModel"),
671671

672672
#' @param object a fitted k-means model.
673673
#' @return \code{summary} returns summary information of the fitted model, which is a list.
674-
#' The list includes the model's \code{k} (number of cluster centers),
674+
#' The list includes the model's \code{k} (the configured number of cluster centers),
675675
#' \code{coefficients} (model cluster centers),
676-
#' \code{size} (number of data points in each cluster), and \code{cluster}
677-
#' (cluster centers of the transformed data).
676+
#' \code{size} (number of data points in each cluster), \code{cluster}
677+
#' (cluster centers of the transformed data), {is.loaded} (whether the model is loaded
678+
#' from a saved file), and \code{clusterSize}
679+
#' (the actual number of cluster centers. When using initMode = "random",
680+
#' \code{clusterSize} may not equal to \code{k}).
678681
#' @rdname spark.kmeans
679682
#' @export
680683
#' @note summary(KMeansModel) since 2.0.0
@@ -686,16 +689,17 @@ setMethod("summary", signature(object = "KMeansModel"),
686689
coefficients <- callJMethod(jobj, "coefficients")
687690
k <- callJMethod(jobj, "k")
688691
size <- callJMethod(jobj, "size")
689-
coefficients <- t(matrix(coefficients, ncol = k))
692+
clusterSize <- callJMethod(jobj, "clusterSize")
693+
coefficients <- t(matrix(coefficients, ncol = clusterSize))
690694
colnames(coefficients) <- unlist(features)
691-
rownames(coefficients) <- 1:k
695+
rownames(coefficients) <- 1:clusterSize
692696
cluster <- if (is.loaded) {
693697
NULL
694698
} else {
695699
dataFrame(callJMethod(jobj, "cluster"))
696700
}
697701
list(k = k, coefficients = coefficients, size = size,
698-
cluster = cluster, is.loaded = is.loaded)
702+
cluster = cluster, is.loaded = is.loaded, clusterSize = clusterSize)
699703
})
700704

701705
# Predicted values based on a k-means model

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,33 @@ test_that("spark.kmeans", {
375375
expect_true(summary2$is.loaded)
376376

377377
unlink(modelPath)
378+
379+
# Test Kmeans on dataset that is sensitive to seed value
380+
col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
381+
col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
382+
col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
383+
cols <- as.data.frame(cbind(col1, col2, col3))
384+
df <- createDataFrame(cols)
385+
386+
model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
387+
initMode = "random", seed = 1, tol = 1E-5)
388+
model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
389+
initMode = "random", seed = 22222, tol = 1E-5)
390+
391+
summary.model1 <- summary(model1)
392+
summary.model2 <- summary(model2)
393+
cluster1 <- summary.model1$cluster
394+
cluster2 <- summary.model2$cluster
395+
clusterSize1 <- summary.model1$clusterSize
396+
clusterSize2 <- summary.model2$clusterSize
397+
398+
# The predicted clusters are different
399+
expect_equal(sort(collect(distinct(select(cluster1, "prediction")))$prediction),
400+
c(0, 1, 2, 3))
401+
expect_equal(sort(collect(distinct(select(cluster2, "prediction")))$prediction),
402+
c(0, 1, 2))
403+
expect_equal(clusterSize1, 4)
404+
expect_equal(clusterSize2, 3)
378405
})
379406

380407
test_that("spark.mlp", {

mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ private[r] class KMeansWrapper private (
4343

4444
lazy val cluster: DataFrame = kMeansModel.summary.cluster
4545

46+
lazy val clusterSize: Int = kMeansModel.clusterCenters.size
47+
4648
def fitted(method: String): DataFrame = {
4749
if (method == "centers") {
4850
kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)

0 commit comments

Comments
 (0)