
Commit 3973403

wangmiao1981 authored and Felix Cheung committed
[SPARK-19456][SPARKR] Add LinearSVC R API
## What changes were proposed in this pull request?

The linear SVM classifier was newly added to ML, and a Python API has been added. This JIRA adds the R-side API. Marked as WIP, as I am designing unit tests.

## How was this patch tested?

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: wm624@hotmail.com <wm624@hotmail.com>

Closes #16800 from wangmiao1981/svc.
1 parent 447b2b5 commit 3973403
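
For context, the user-facing flow this commit adds looks like the following sketch, adapted from the roxygen example in the diff below (the iris subset and the regParam value are illustrative, not prescriptive):

```r
library(SparkR)
sparkR.session()

# Binary classification on two iris species, as in the patch's example.
df <- createDataFrame(iris)
training <- df[df$Species %in% c("versicolor", "virginica"), ]

model <- spark.svmLinear(training, Species ~ ., regParam = 0.5)
summary(model)            # coefficients, intercept, numClasses, numFeatures
fitted <- predict(model, training)
head(collect(fitted))
```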

File tree

7 files changed (+342, -4 lines)


R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion

@@ -65,7 +65,8 @@ exportMethods("glm",
               "spark.logit",
               "spark.randomForest",
               "spark.gbt",
-              "spark.bisectingKmeans")
+              "spark.bisectingKmeans",
+              "spark.svmLinear")

 # Job group lifecycle management methods
 export("setJobGroup",

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions

@@ -1401,6 +1401,10 @@ setGeneric("spark.randomForest",
 #' @export
 setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })

+#' @rdname spark.svmLinear
+#' @export
+setGeneric("spark.svmLinear", function(data, formula, ...) { standardGeneric("spark.svmLinear") })
+
 #' @rdname spark.lda
 #' @export
 setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") })

R/pkg/R/mllib_classification.R

Lines changed: 132 additions & 0 deletions

@@ -18,6 +18,13 @@
 # mllib_regression.R: Provides methods for MLlib classification algorithms
 # (except for tree-based algorithms) integration

+#' S4 class that represents a LinearSVCModel
+#'
+#' @param jobj a Java object reference to the backing Scala LinearSVCModel
+#' @export
+#' @note LinearSVCModel since 2.2.0
+setClass("LinearSVCModel", representation(jobj = "jobj"))
+
 #' S4 class that represents an LogisticRegressionModel
 #'
 #' @param jobj a Java object reference to the backing Scala LogisticRegressionModel
@@ -39,6 +46,131 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
 #' @note NaiveBayesModel since 2.0.0
 setClass("NaiveBayesModel", representation(jobj = "jobj"))

+#' Linear SVM Model
+#'
+#' Fits a linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in the e1071 package.
+#' Users can print the model, make predictions on the produced model, and save the model to the input path.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param regParam The regularization parameter.
+#' @param maxIter Maximum iteration number.
+#' @param tol Convergence tolerance of iterations.
+#' @param standardization Whether to standardize the training features before fitting the model. The coefficients
+#'                        of models will always be returned on the original scale, so it will be transparent for
+#'                        users. Note that with or without standardization, the models should always converge
+#'                        to the same solution when no regularization is applied.
+#' @param threshold The threshold in binary classification, in range [0, 1].
+#' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. The default value should be good for most cases.
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.svmLinear} returns a fitted linear SVM model.
+#' @rdname spark.svmLinear
+#' @aliases spark.svmLinear,SparkDataFrame,formula-method
+#' @name spark.svmLinear
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' df <- createDataFrame(iris)
+#' training <- df[df$Species %in% c("versicolor", "virginica"), ]
+#' model <- spark.svmLinear(training, Species ~ ., regParam = 0.5)
+#' summary <- summary(model)
+#'
+#' # fitted values on training data
+#' fitted <- predict(model, training)
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and predict
+#' # Note that summary does not work on a loaded model
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.svmLinear since 2.2.0
+setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
+                   threshold = 0.0, weightCol = NULL, aggregationDepth = 2) {
+            formula <- paste(deparse(formula), collapse = "")
+
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
+            }
+
+            jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
+                                data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
+                                as.numeric(tol), as.logical(standardization), as.numeric(threshold),
+                                weightCol, as.integer(aggregationDepth))
+            new("LinearSVCModel", jobj = jobj)
+          })
+
+# Predicted values based on a LinearSVCModel
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on a LinearSVCModel.
+#' @rdname spark.svmLinear
+#' @aliases predict,LinearSVCModel,SparkDataFrame-method
+#' @export
+#' @note predict(LinearSVCModel) since 2.2.0
+setMethod("predict", signature(object = "LinearSVCModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Get the summary of a LinearSVCModel
+
+#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}.
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#'         The list includes \code{coefficients} (coefficients of the fitted model),
+#'         \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes),
+#'         and \code{numFeatures} (number of features).
+#' @rdname spark.svmLinear
+#' @aliases summary,LinearSVCModel-method
+#' @export
+#' @note summary(LinearSVCModel) since 2.2.0
+setMethod("summary", signature(object = "LinearSVCModel"),
+          function(object) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "features")
+            labels <- callJMethod(jobj, "labels")
+            coefficients <- callJMethod(jobj, "coefficients")
+            nCol <- length(coefficients) / length(features)
+            coefficients <- matrix(unlist(coefficients), ncol = nCol)
+            intercept <- callJMethod(jobj, "intercept")
+            numClasses <- callJMethod(jobj, "numClasses")
+            numFeatures <- callJMethod(jobj, "numFeatures")
+            if (nCol == 1) {
+              colnames(coefficients) <- c("Estimate")
+            } else {
+              colnames(coefficients) <- unlist(labels)
+            }
+            rownames(coefficients) <- unlist(features)
+            list(coefficients = coefficients, intercept = intercept,
+                 numClasses = numClasses, numFeatures = numFeatures)
+          })
+
+# Save fitted LinearSVCModel to the input path
+
+#' @param path The directory where the model is saved.
+#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE,
+#'                  which means an exception is thrown if the output path exists.
+#'
+#' @rdname spark.svmLinear
+#' @aliases write.ml,LinearSVCModel,character-method
+#' @export
+#' @note write.ml(LinearSVCModel, character) since 2.2.0
+setMethod("write.ml", signature(object = "LinearSVCModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
 #' Logistic Regression Model
 #'
 #' Fits an logistic regression model against a SparkDataFrame. It supports "binomial": Binary logistic regression
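
One detail in the summary method worth unpacking is the coefficient reshaping: the JVM side returns a flat list, and nCol = length(coefficients) / length(features) recovers the matrix shape, with one column per label (or a single "Estimate" column in the binary case). A standalone sketch of that logic in plain R, with made-up values that are not taken from a real model:

```r
# Stand-ins for what callJMethod() returns from the Scala wrapper:
# a flat, column-major list of coefficients plus the feature names.
features <- list("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width")
coefficients <- list(-0.16, -0.46, 0.23, 1.06)   # illustrative values

nCol <- length(coefficients) / length(features)  # 4 / 4 = 1 for a binary SVM
coefficients <- matrix(unlist(coefficients), ncol = nCol)
if (nCol == 1) {
  colnames(coefficients) <- c("Estimate")        # single-column, binary case
}
rownames(coefficients) <- unlist(features)
print(coefficients)
#              Estimate
# Sepal_Length    -0.16
# Sepal_Width     -0.46
# Petal_Length     0.23
# Petal_Width      1.06
```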

R/pkg/R/mllib_utils.R

Lines changed: 6 additions & 3 deletions

@@ -35,8 +35,9 @@
 #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
 #' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
 #' @seealso \link{spark.kmeans},
-#' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
-#' @seealso \link{spark.randomForest}, \link{spark.survreg},
+#' @seealso \link{spark.lda}, \link{spark.logit},
+#' @seealso \link{spark.mlp}, \link{spark.naiveBayes},
+#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear},
 #' @seealso \link{read.ml}
 NULL

@@ -51,7 +52,7 @@ NULL
 #' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
 #' @seealso \link{spark.kmeans},
 #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
-#' @seealso \link{spark.randomForest}, \link{spark.survreg}
+#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear}
 NULL

 write_internal <- function(object, path, overwrite = FALSE) {
@@ -115,6 +116,8 @@ read.ml <- function(path) {
     new("GBTClassificationModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.BisectingKMeansWrapper")) {
     new("BisectingKMeansModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) {
+    new("LinearSVCModel", jobj = jobj)
   } else {
     stop("Unsupported model: ", jobj)
   }
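
The new isInstanceOf branch is what makes the save/load round trip type-stable: write.ml persists the model through LinearSVCWrapper, and read.ml dispatches on the wrapper's JVM class to rebuild a LinearSVCModel from the same path. A hedged usage sketch of that round trip (session setup and paths are illustrative):

```r
library(SparkR)
sparkR.session()

training <- createDataFrame(iris)
training <- training[training$Species %in% c("versicolor", "virginica"), ]
model <- spark.svmLinear(training, Species ~ ., regParam = 0.01)

# Round trip: read.ml's isInstanceOf() chain above restores a LinearSVCModel.
modelPath <- tempfile(pattern = "spark-svm-linear")
write.ml(model, modelPath)
restored <- read.ml(modelPath)
head(collect(predict(restored, training)))
unlink(modelPath, recursive = TRUE)
```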

R/pkg/inst/tests/testthat/test_mllib_classification.R

Lines changed: 44 additions & 0 deletions

@@ -27,6 +27,50 @@ absoluteSparkPath <- function(x) {
   file.path(sparkHome, x)
 }

+test_that("spark.svmLinear", {
+  df <- suppressWarnings(createDataFrame(iris))
+  training <- df[df$Species %in% c("versicolor", "virginica"), ]
+  model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10)
+  summary <- summary(model)
+
+  # test that summary returns coefficients as a numeric matrix
+  expect_true(class(summary$coefficients) == "matrix")
+  expect_true(class(summary$coefficients[, 1]) == "numeric")
+
+  coefs <- summary$coefficients[, "Estimate"]
+  expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085)
+  expect_true(all(abs(coefs - expected_coefs) < 0.1))
+  expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2)
+
+  # Test prediction with string label
+  prediction <- predict(model, training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character")
+  expected <- c("versicolor", "versicolor", "versicolor", "virginica", "virginica",
+                "virginica", "virginica", "virginica", "virginica", "virginica")
+  expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected)
+
+  # Test model save and load
+  modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  coefs <- summary(model)$coefficients
+  coefs2 <- summary(model2)$coefficients
+  expect_equal(coefs, coefs2)
+  unlink(modelPath)
+
+  # Test prediction with numeric label
+  label <- c(0.0, 0.0, 0.0, 1.0, 1.0)
+  feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
+  data <- as.data.frame(cbind(label, feature))
+  df <- createDataFrame(data)
+  model <- spark.svmLinear(df, label ~ feature, regParam = 0.1)
+  prediction <- collect(select(predict(model, df), "prediction"))
+  expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+})
+
 test_that("spark.logit", {
   # R code to reproduce the result.
   # nolint start
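
To run just this spec locally, one option is testthat::test_file against an installed SparkR build. This is a hedged sketch: the SPARK_HOME path and the local[2] master are assumptions, and Spark's own R/run-tests.sh remains the canonical entry point.

```r
library(testthat)

Sys.setenv(SPARK_HOME = "/path/to/spark")   # illustrative checkout path
library(SparkR, lib.loc = file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))
sparkR.session(master = "local[2]")         # small local session for tests

test_file("R/pkg/inst/tests/testthat/test_mllib_classification.R")
```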
