
[SPARK-14303] [ML] [SparkR] Define and use KMeansWrapper for SparkR::kmeans #12039

Closed · wants to merge 3 commits
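For orientation, here is a minimal sketch of the user-facing API this PR introduces, assuming training and testing are SparkR DataFrames with numeric columns (the object names are illustrative, not from the PR):

    # fit a k-means model with 2 centers on a SparkR DataFrame
    model <- kmeans(training, centers = 2, iter.max = 10, algorithm = "random")

    # centers (as a k x d coefficient matrix), cluster sizes, and assignments
    summary(model)

    # cluster assignment for each training row
    showDF(fitted(model, "classes"))

    # predictions for held-out data, returned in a "prediction" column
    showDF(predict(model, testing))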
91 changes: 62 additions & 29 deletions R/pkg/R/mllib.R
@@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' @export
setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))

#' @title S4 class that represents a KMeansModel
#' @param jobj a Java object reference to the backing Scala KMeansModel
#' @export
setClass("KMeansModel", representation(jobj = "jobj"))

#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"),
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
} else if (modelName == "KMeansModel") {
modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansModelSize", object@model)
cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansCluster", object@model, "classes")
k <- unlist(modelSize)[1]
size <- unlist(modelSize)[-1]
coefficients <- t(matrix(coefficients, ncol = k))
colnames(coefficients) <- unlist(features)
rownames(coefficients) <- 1:k
return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
@@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
#' @examples
#' \dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
#'}
#' }
setMethod("kmeans", signature(x = "DataFrame"),
function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
columnNames <- as.array(colnames(x))
algorithm <- match.arg(algorithm)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
algorithm, iter.max, centers, columnNames)
return(new("PipelineModel", model = model))
jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
centers, iter.max, algorithm, columnNames)
return(new("KMeansModel", jobj = jobj))
})

#' Get fitted result from a model
#' Get fitted result from a k-means model
#'
#' Get fitted result from a model, similarly to R's fitted().
#' Get fitted result from a k-means model, similarly to R's fitted().
#'
#' @param object A fitted MLlib model
#' @param object A fitted k-means model
#' @return DataFrame containing fitted values
#' @rdname fitted
#' @export
@@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
#'}
setMethod("fitted", signature(object = "PipelineModel"),
setMethod("fitted", signature(object = "KMeansModel"),
function(object, method = c("centers", "classes"), ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelName", object@model)
method <- match.arg(method)
return(dataFrame(callJMethod(object@jobj, "fitted", method)))
})

if (modelName == "KMeansModel") {
method <- match.arg(method)
fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansCluster", object@model, method)
return(dataFrame(fittedResult))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
#' Get the summary of a k-means model
#'
#' Returns the summary of a k-means model produced by kmeans(),
#' similarly to R's summary().
#'
#' @param object a fitted k-means model
#' @return the model's coefficients, size and cluster
#' @rdname summary
#' @export
#' @examples
#' \dontrun{
#' model <- kmeans(trainingData, 2)
#' summary(model)
#' }
setMethod("summary", signature(object = "KMeansModel"),
function(object, ...) {
jobj <- object@jobj
features <- callJMethod(jobj, "features")
coefficients <- callJMethod(jobj, "coefficients")
cluster <- callJMethod(jobj, "cluster")
k <- callJMethod(jobj, "k")
size <- callJMethod(jobj, "size")
coefficients <- t(matrix(coefficients, ncol = k))
colnames(coefficients) <- unlist(features)
rownames(coefficients) <- 1:k
return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
})

#' Make predictions from a k-means model
#'
#' Make predictions from a model produced by kmeans().
#'
#' @param object A fitted k-means model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted labels in a column named "prediction"
#' @rdname predict
#' @export
#' @examples
#' \dontrun{
#' model <- kmeans(trainingData, 2)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#' }
setMethod("predict", signature(object = "KMeansModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})

#' Fit a Bernoulli naive Bayes model
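A note on the two method values accepted by the fitted() method above, sketched with the hypothetical model from the earlier example: "classes" returns only the cluster assignments, while "centers" returns the training predictions with the internal assembled features vector dropped.

    # one column of cluster labels, one row per training point
    classes <- fitted(model, "classes")

    # the original training columns plus a "prediction" column
    # (the assembled "features" vector column is dropped before returning)
    centers <- fitted(model, "centers")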
85 changes: 85 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.r

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.DataFrame

private[r] class KMeansWrapper private (
pipeline: PipelineModel) {

private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]

lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)

private lazy val attrs = AttributeGroup.fromStructField(
kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))

lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)

lazy val k: Int = kMeansModel.getK

lazy val size: Array[Int] = kMeansModel.summary.size

lazy val cluster: DataFrame = kMeansModel.summary.cluster

def fitted(method: String): DataFrame = {
if (method == "centers") {
kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)
} else if (method == "classes") {
kMeansModel.summary.cluster
} else {
throw new UnsupportedOperationException(
s"Method (centers or classes) required but $method found.")
}
}

def transform(dataset: DataFrame): DataFrame = {
pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
}

}

private[r] object KMeansWrapper {

def fit(
data: DataFrame,
k: Double,
maxIter: Double,
initMode: String,
columns: Array[String]): KMeansWrapper = {

val assembler = new VectorAssembler()
.setInputCols(columns)
.setOutputCol("features")

val kMeans = new KMeans()
.setK(k.toInt)
.setMaxIter(maxIter.toInt)
.setInitMode(initMode)

val pipeline = new Pipeline()
.setStages(Array(assembler, kMeans))
.fit(data)

new KMeansWrapper(pipeline)
}
}
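Note the hand-off of the cluster centers across the JVM boundary: the wrapper flattens the k centers into one array (clusterCenters.flatMap(_.toArray)), and the R summary() method rebuilds the k x d matrix via t(matrix(coefficients, ncol = k)). A self-contained R sketch with made-up numbers:

    # k = 2 centers in d = 3 dimensions, flattened center-by-center
    coefficients <- c(1, 2, 3,
                      4, 5, 6)
    k <- 2

    # matrix() fills column-wise, so ncol = k places one center per column;
    # transposing gives the conventional one-center-per-row layout
    centers <- t(matrix(coefficients, ncol = k))
    #      [,1] [,2] [,3]
    # [1,]    1    2    3
    # [2,]    4    5    6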
52 changes: 1 addition & 51 deletions mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.api.r
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.DataFrame

@@ -52,22 +51,6 @@ private[r] object SparkRWrappers {
pipeline.fit(df)
}

def fitKMeans(
df: DataFrame,
initMode: String,
maxIter: Double,
k: Double,
columns: Array[String]): PipelineModel = {
val assembler = new VectorAssembler().setInputCols(columns)
val kMeans = new KMeans()
.setInitMode(initMode)
.setMaxIter(maxIter.toInt)
.setK(k.toInt)
.setFeaturesCol(assembler.getOutputCol)
val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
pipeline.fit(df)
}

def getModelCoefficients(model: PipelineModel): Array[Double] = {
model.stages.last match {
case m: LinearRegressionModel => {
@@ -89,8 +72,6 @@ private[r] object SparkRWrappers {
m.coefficients.toArray
}
}
case m: KMeansModel =>
m.clusterCenters.flatMap(_.toArray)
}
}

@@ -104,31 +85,6 @@
}
}

def getKMeansModelSize(model: PipelineModel): Array[Int] = {
model.stages.last match {
case m: KMeansModel => Array(m.getK) ++ m.summary.size
case other => throw new UnsupportedOperationException(
s"KMeansModel required but ${other.getClass.getSimpleName} found.")
}
}

def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
model.stages.last match {
case m: KMeansModel =>
if (method == "centers") {
// Drop the assembled vector for easy-print to R side.
m.summary.predictions.drop(m.summary.featuresCol)
} else if (method == "classes") {
m.summary.cluster
} else {
throw new UnsupportedOperationException(
s"Method (centers or classes) required but $method found.")
}
case other => throw new UnsupportedOperationException(
s"KMeansModel required but ${other.getClass.getSimpleName} found.")
}
}

def getModelFeatures(model: PipelineModel): Array[String] = {
model.stages.last match {
case m: LinearRegressionModel =>
@@ -147,10 +103,6 @@
} else {
attrs.attributes.get.map(_.name.get)
}
case m: KMeansModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
attrs.attributes.get.map(_.name.get)
}
}

@@ -160,8 +112,6 @@
"LinearRegressionModel"
case m: LogisticRegressionModel =>
"LogisticRegressionModel"
case m: KMeansModel =>
"KMeansModel"
}
}
}
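Taken together, the PR swaps SparkRWrappers' string-based model dispatch for one dedicated JVM wrapper per algorithm; the R side then only needs SparkR's internal bridge helpers. A sketch assembled from the calls in the diff above (df is an assumed SparkR DataFrame; the numeric arguments mirror centers and iter.max):

    # fit: a static method on the Scala companion object returns a jobj handle
    jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit",
                        df@sdf, 2, 10, "k-means||", as.array(colnames(df)))

    # accessors are instance methods on the handle
    k <- callJMethod(jobj, "k")
    sizes <- callJMethod(jobj, "size")

    # DataFrame-returning methods get re-wrapped into SparkR DataFrames
    assignments <- dataFrame(callJMethod(jobj, "cluster"))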