@@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
32
32
# ' @export
33
33
setClass ("AFTSurvivalRegressionModel ", representation(jobj = "jobj"))
34
34
35
+ # ' @title S4 class that represents a KMeansModel
36
+ # ' @param jobj a Java object reference to the backing Scala KMeansModel
37
+ # ' @export
38
+ setClass ("KMeansModel ", representation(jobj = "jobj"))
39
+
35
40
# ' Fits a generalized linear model
36
41
# '
37
42
# ' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"),
154
159
colnames(coefficients ) <- c(" Estimate" )
155
160
rownames(coefficients ) <- unlist(features )
156
161
return (list (coefficients = coefficients ))
157
- } else if (modelName == " KMeansModel" ) {
158
- modelSize <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
159
- " getKMeansModelSize" , object @ model )
160
- cluster <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
161
- " getKMeansCluster" , object @ model , " classes" )
162
- k <- unlist(modelSize )[1 ]
163
- size <- unlist(modelSize )[- 1 ]
164
- coefficients <- t(matrix (coefficients , ncol = k ))
165
- colnames(coefficients ) <- unlist(features )
166
- rownames(coefficients ) <- 1 : k
167
- return (list (coefficients = coefficients , size = size , cluster = dataFrame(cluster )))
168
162
} else {
169
163
stop(paste(" Unsupported model" , modelName , sep = " " ))
170
164
}
@@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
213
207
# ' @examples
214
208
# ' \dontrun{
215
209
# ' model <- kmeans(x, centers = 2, algorithm="random")
216
- # '}
210
+ # ' }
217
211
setMethod ("kmeans ", signature(x = "DataFrame"),
218
212
function (x , centers , iter.max = 10 , algorithm = c(" random" , " k-means||" )) {
219
213
columnNames <- as.array(colnames(x ))
220
214
algorithm <- match.arg(algorithm )
221
- model <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers " , " fitKMeans " , x @ sdf ,
222
- algorithm , iter.max , centers , columnNames )
223
- return (new(" PipelineModel " , model = model ))
215
+ jobj <- callJStatic(" org.apache.spark.ml.r.KMeansWrapper " , " fit " , x @ sdf ,
216
+ centers , iter.max , algorithm , columnNames )
217
+ return (new(" KMeansModel " , jobj = jobj ))
224
218
})
225
219
226
- # ' Get fitted result from a model
220
+ # ' Get fitted result from a k-means model
227
221
# '
228
- # ' Get fitted result from a model, similarly to R's fitted().
222
+ # ' Get fitted result from a k-means model, similarly to R's fitted().
229
223
# '
230
- # ' @param object A fitted MLlib model
224
+ # ' @param object A fitted k-means model
231
225
# ' @return DataFrame containing fitted values
232
226
# ' @rdname fitted
233
227
# ' @export
@@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
237
231
# ' fitted.model <- fitted(model)
238
232
# ' showDF(fitted.model)
239
233
# '}
240
- setMethod ("fitted ", signature(object = "PipelineModel "),
234
+ setMethod ("fitted ", signature(object = "KMeansModel "),
241
235
function (object , method = c(" centers" , " classes" ), ... ) {
242
- modelName <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
243
- " getModelName" , object @ model )
236
+ method <- match.arg(method )
237
+ return (dataFrame(callJMethod(object @ jobj , " fitted" , method )))
238
+ })
244
239
245
- if (modelName == " KMeansModel" ) {
246
- method <- match.arg(method )
247
- fittedResult <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
248
- " getKMeansCluster" , object @ model , method )
249
- return (dataFrame(fittedResult ))
250
- } else {
251
- stop(paste(" Unsupported model" , modelName , sep = " " ))
252
- }
240
+ # ' Get the summary of a k-means model
241
+ # '
242
+ # ' Returns the summary of a k-means model produced by kmeans(),
243
+ # ' similarly to R's summary().
244
+ # '
245
+ # ' @param object a fitted k-means model
246
+ # ' @return the model's coefficients, size and cluster
247
+ # ' @rdname summary
248
+ # ' @export
249
+ # ' @examples
250
+ # ' \dontrun{
251
+ # ' model <- kmeans(trainingData, 2)
252
+ # ' summary(model)
253
+ # ' }
254
+ setMethod ("summary ", signature(object = "KMeansModel"),
255
+ function (object , ... ) {
256
+ jobj <- object @ jobj
257
+ features <- callJMethod(jobj , " features" )
258
+ coefficients <- callJMethod(jobj , " coefficients" )
259
+ cluster <- callJMethod(jobj , " cluster" )
260
+ k <- callJMethod(jobj , " k" )
261
+ size <- callJMethod(jobj , " size" )
262
+ coefficients <- t(matrix (coefficients , ncol = k ))
263
+ colnames(coefficients ) <- unlist(features )
264
+ rownames(coefficients ) <- 1 : k
265
+ return (list (coefficients = coefficients , size = size , cluster = dataFrame(cluster )))
266
+ })
267
+
268
+ # ' Make predictions from a k-means model
269
+ # '
270
+ # ' Make predictions from a model produced by kmeans().
271
+ # '
272
+ # ' @param object A fitted k-means model
273
+ # ' @param newData DataFrame for testing
274
+ # ' @return DataFrame containing predicted labels in a column named "prediction"
275
+ # ' @rdname predict
276
+ # ' @export
277
+ # ' @examples
278
+ # ' \dontrun{
279
+ # ' model <- kmeans(trainingData, 2)
280
+ # ' predicted <- predict(model, testData)
281
+ # ' showDF(predicted)
282
+ # ' }
283
+ setMethod ("predict ", signature(object = "KMeansModel"),
284
+ function (object , newData ) {
285
+ return (dataFrame(callJMethod(object @ jobj , " transform" , newData @ sdf )))
253
286
})
254
287
255
288
# ' Fit a Bernoulli naive Bayes model
0 commit comments