
Commit 22249af

yanboliang authored and mengxr committed
[SPARK-14303][ML][SPARKR] Define and use KMeansWrapper for SparkR::kmeans
## What changes were proposed in this pull request?

Define and use ```KMeansWrapper``` for ```SparkR::kmeans```. This is only a code refactor of the original ```KMeans``` wrapper.

## How was this patch tested?

Existing tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12039 from yanboliang/spark-14059.
1 parent 26867eb commit 22249af
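Below is a minimal end-to-end sketch of the refactored SparkR API. The session setup and data here are illustrative assumptions (a Spark 1.6/2.0-era SparkR session with a `sqlContext`), not part of the patch:

```r
library(SparkR)
sc <- sparkR.init()                  # assumed local Spark context
sqlContext <- sparkRSQL.init(sc)

# Two well-separated 2-D clusters; kmeans() uses every column of the input
# DataFrame as a feature, so keep the columns numeric.
localDF <- data.frame(x = c(rnorm(50), rnorm(50, mean = 5)),
                      y = c(rnorm(50), rnorm(50, mean = 5)))
df <- createDataFrame(sqlContext, localDF)

model <- kmeans(df, centers = 2, algorithm = "random")  # now a KMeansModel, not a PipelineModel
summary(model)                    # list with coefficients, size, cluster
showDF(fitted(model, "classes"))  # cluster assignment for each training row
showDF(predict(model, df))        # appends a "prediction" column
```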

File tree

3 files changed: +148 −80 lines changed

R/pkg/R/mllib.R

Lines changed: 62 additions & 29 deletions
```diff
@@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @export
 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
 
+#' @title S4 class that represents a KMeansModel
+#' @param jobj a Java object reference to the backing Scala KMeansModel
+#' @export
+setClass("KMeansModel", representation(jobj = "jobj"))
+
 #' Fits a generalized linear model
 #'
 #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"),
             colnames(coefficients) <- c("Estimate")
             rownames(coefficients) <- unlist(features)
             return(list(coefficients = coefficients))
-          } else if (modelName == "KMeansModel") {
-            modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                     "getKMeansModelSize", object@model)
-            cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                   "getKMeansCluster", object@model, "classes")
-            k <- unlist(modelSize)[1]
-            size <- unlist(modelSize)[-1]
-            coefficients <- t(matrix(coefficients, ncol = k))
-            colnames(coefficients) <- unlist(features)
-            rownames(coefficients) <- 1:k
-            return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
           } else {
             stop(paste("Unsupported model", modelName, sep = " "))
           }
@@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
 #' @examples
 #' \dontrun{
 #' model <- kmeans(x, centers = 2, algorithm="random")
-#'}
+#' }
 setMethod("kmeans", signature(x = "DataFrame"),
           function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
             columnNames <- as.array(colnames(x))
             algorithm <- match.arg(algorithm)
-            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
-                                 algorithm, iter.max, centers, columnNames)
-            return(new("PipelineModel", model = model))
+            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
+                                centers, iter.max, algorithm, columnNames)
+            return(new("KMeansModel", jobj = jobj))
           })
 
-#' Get fitted result from a model
+#' Get fitted result from a k-means model
 #'
-#' Get fitted result from a model, similarly to R's fitted().
+#' Get fitted result from a k-means model, similarly to R's fitted().
 #'
-#' @param object A fitted MLlib model
+#' @param object A fitted k-means model
 #' @return DataFrame containing fitted values
 #' @rdname fitted
 #' @export
@@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
 #' fitted.model <- fitted(model)
 #' showDF(fitted.model)
 #'}
-setMethod("fitted", signature(object = "PipelineModel"),
+setMethod("fitted", signature(object = "KMeansModel"),
           function(object, method = c("centers", "classes"), ...) {
-            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                     "getModelName", object@model)
+            method <- match.arg(method)
+            return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+          })
 
-            if (modelName == "KMeansModel") {
-              method <- match.arg(method)
-              fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                          "getKMeansCluster", object@model, method)
-              return(dataFrame(fittedResult))
-            } else {
-              stop(paste("Unsupported model", modelName, sep = " "))
-            }
+#' Get the summary of a k-means model
+#'
+#' Returns the summary of a k-means model produced by kmeans(),
+#' similarly to R's summary().
+#'
+#' @param object a fitted k-means model
+#' @return the model's coefficients, size and cluster
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "KMeansModel"),
+          function(object, ...) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "features")
+            coefficients <- callJMethod(jobj, "coefficients")
+            cluster <- callJMethod(jobj, "cluster")
+            k <- callJMethod(jobj, "k")
+            size <- callJMethod(jobj, "size")
+            coefficients <- t(matrix(coefficients, ncol = k))
+            colnames(coefficients) <- unlist(features)
+            rownames(coefficients) <- 1:k
+            return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+          })
+
+#' Make predictions from a k-means model
+#'
+#' Make predictions from a model produced by kmeans().
+#'
+#' @param object A fitted k-means model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted labels in a column named "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "KMeansModel"),
+          function(object, newData) {
+            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
           })
 
 #' Fit a Bernoulli naive Bayes model
```
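One detail worth calling out in the new `summary` method above: the JVM side hands back the cluster centers as a single flat vector (`clusterCenters.flatMap(_.toArray)`), and `t(matrix(coefficients, ncol = k))` rebuilds the k-by-features matrix from it. A small base-R illustration with made-up numbers:

```r
# Flat center vector as returned by KMeansWrapper.coefficients:
# all dimensions of center 1, then all dimensions of center 2, ...
coefficients <- c(1, 2,    # center 1: (x = 1, y = 2)
                  10, 20)  # center 2: (x = 10, y = 20)
k <- 2

# matrix() fills column-major, so each column holds one center;
# transposing yields one row per cluster, one column per feature.
centers <- t(matrix(coefficients, ncol = k))
colnames(centers) <- c("x", "y")
rownames(centers) <- 1:k
centers
#>    x  y
#> 1  1  2
#> 2 10 20
```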
mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala (new file)

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.r
19+
20+
import org.apache.spark.ml.{Pipeline, PipelineModel}
21+
import org.apache.spark.ml.attribute.AttributeGroup
22+
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
23+
import org.apache.spark.ml.feature.VectorAssembler
24+
import org.apache.spark.sql.DataFrame
25+
26+
private[r] class KMeansWrapper private (
27+
pipeline: PipelineModel) {
28+
29+
private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
30+
31+
lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
32+
33+
private lazy val attrs = AttributeGroup.fromStructField(
34+
kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
35+
36+
lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
37+
38+
lazy val k: Int = kMeansModel.getK
39+
40+
lazy val size: Array[Int] = kMeansModel.summary.size
41+
42+
lazy val cluster: DataFrame = kMeansModel.summary.cluster
43+
44+
def fitted(method: String): DataFrame = {
45+
if (method == "centers") {
46+
kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)
47+
} else if (method == "classes") {
48+
kMeansModel.summary.cluster
49+
} else {
50+
throw new UnsupportedOperationException(
51+
s"Method (centers or classes) required but $method found.")
52+
}
53+
}
54+
55+
def transform(dataset: DataFrame): DataFrame = {
56+
pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
57+
}
58+
59+
}
60+
61+
private[r] object KMeansWrapper {
62+
63+
def fit(
64+
data: DataFrame,
65+
k: Double,
66+
maxIter: Double,
67+
initMode: String,
68+
columns: Array[String]): KMeansWrapper = {
69+
70+
val assembler = new VectorAssembler()
71+
.setInputCols(columns)
72+
.setOutputCol("features")
73+
74+
val kMeans = new KMeans()
75+
.setK(k.toInt)
76+
.setMaxIter(maxIter.toInt)
77+
.setInitMode(initMode)
78+
79+
val pipeline = new Pipeline()
80+
.setStages(Array(assembler, kMeans))
81+
.fit(data)
82+
83+
new KMeansWrapper(pipeline)
84+
}
85+
}
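Note that `fit` declares `k` and `maxIter` as `Double` even though KMeans needs integers: numeric values from R arrive on the JVM as doubles, so the wrapper truncates with `.toInt`. A sketch of the backend call that `kmeans()` issues internally (this uses SparkR's non-exported `callJStatic`, shown here purely for illustration; user code should call `kmeans()` instead):

```r
# Mirrors the callJStatic invocation in setMethod("kmeans", ...) above.
# centers = 2 and iter.max = 10 travel across the backend as doubles
# and are converted to Int inside KMeansWrapper.fit.
jobj <- SparkR:::callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit",
                             df@sdf,                  # underlying Scala DataFrame
                             2,                       # k (numeric -> Double)
                             10,                      # maxIter (numeric -> Double)
                             "k-means||",             # initMode
                             as.array(colnames(df)))  # feature columns
model <- new("KMeansModel", jobj = jobj)
```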

mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala

Lines changed: 1 addition & 51 deletions
```diff
@@ -20,8 +20,7 @@ package org.apache.spark.ml.api.r
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
+import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.sql.DataFrame
 
@@ -52,22 +51,6 @@ private[r] object SparkRWrappers {
     pipeline.fit(df)
   }
 
-  def fitKMeans(
-      df: DataFrame,
-      initMode: String,
-      maxIter: Double,
-      k: Double,
-      columns: Array[String]): PipelineModel = {
-    val assembler = new VectorAssembler().setInputCols(columns)
-    val kMeans = new KMeans()
-      .setInitMode(initMode)
-      .setMaxIter(maxIter.toInt)
-      .setK(k.toInt)
-      .setFeaturesCol(assembler.getOutputCol)
-    val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
-    pipeline.fit(df)
-  }
-
   def getModelCoefficients(model: PipelineModel): Array[Double] = {
     model.stages.last match {
       case m: LinearRegressionModel => {
@@ -89,8 +72,6 @@ private[r] object SparkRWrappers {
           m.coefficients.toArray
         }
       }
-      case m: KMeansModel =>
-        m.clusterCenters.flatMap(_.toArray)
     }
   }
 
@@ -104,31 +85,6 @@ private[r] object SparkRWrappers {
     }
   }
 
-  def getKMeansModelSize(model: PipelineModel): Array[Int] = {
-    model.stages.last match {
-      case m: KMeansModel => Array(m.getK) ++ m.summary.size
-      case other => throw new UnsupportedOperationException(
-        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
-    }
-  }
-
-  def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
-    model.stages.last match {
-      case m: KMeansModel =>
-        if (method == "centers") {
-          // Drop the assembled vector for easy-print to R side.
-          m.summary.predictions.drop(m.summary.featuresCol)
-        } else if (method == "classes") {
-          m.summary.cluster
-        } else {
-          throw new UnsupportedOperationException(
-            s"Method (centers or classes) required but $method found.")
-        }
-      case other => throw new UnsupportedOperationException(
-        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
-    }
-  }
-
   def getModelFeatures(model: PipelineModel): Array[String] = {
     model.stages.last match {
       case m: LinearRegressionModel =>
@@ -147,10 +103,6 @@ private[r] object SparkRWrappers {
         } else {
           attrs.attributes.get.map(_.name.get)
         }
-      case m: KMeansModel =>
-        val attrs = AttributeGroup.fromStructField(
-          m.summary.predictions.schema(m.summary.featuresCol))
-        attrs.attributes.get.map(_.name.get)
     }
   }
 
@@ -160,8 +112,6 @@ private[r] object SparkRWrappers {
         "LinearRegressionModel"
       case m: LogisticRegressionModel =>
         "LogisticRegressionModel"
-      case m: KMeansModel =>
-        "KMeansModel"
     }
   }
 }
```
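The net effect of the deletions above is that SparkRWrappers no longer special-cases k-means at runtime: dispatch moves to R's S4 method system, with one model class per wrapper instead of a shared `PipelineModel` plus string checks on the model name. A toy sketch of that pattern (class and slot names invented for illustration, not SparkR's):

```r
library(methods)

# One concrete S4 class per model type.
setClass("ToyKMeansModel", representation(k = "numeric"))

# summary() dispatches on the class, so unsupported model types fail at
# dispatch time rather than via stop("Unsupported model ...").
setMethod("summary", signature(object = "ToyKMeansModel"),
          function(object, ...) list(k = object@k))

summary(new("ToyKMeansModel", k = 2))
#> $k
#> [1] 2
```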
