Skip to content

Commit 1022049

Browse files
felixcheungFelix Cheung
authored and
Felix Cheung
committed
[SPARK-19133][SPARKR][ML][BACKPORT-2.1] fix glm for Gamma, clarify glm family supported
## What changes were proposed in this pull request? backporting to 2.1, 2.0 and 1.6 ## How was this patch tested? unit tests Author: Felix Cheung <felixcheung_m@hotmail.com> Closes #16532 from felixcheung/rgammabackport.
1 parent 230607d commit 1022049

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

R/pkg/R/mllib.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ predict_internal <- function(object, newData) {
184184
#' This can be a character string naming a family function, a family function or
185185
#' the result of a call to a family function. Refer R family at
186186
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
187+
#' Currently these families are supported: \code{binomial}, \code{gaussian},
188+
#' \code{Gamma}, and \code{poisson}.
187189
#' @param tol positive convergence tolerance of iterations.
188190
#' @param maxIter integer giving the maximal number of IRLS iterations.
189191
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
@@ -236,8 +238,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
236238
weightCol <- ""
237239
}
238240

241+
# For known families, Gamma is upper-cased
239242
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
240-
"fit", formula, data@sdf, family$family, family$link,
243+
"fit", formula, data@sdf, tolower(family$family), family$link,
241244
tol, as.integer(maxIter), as.character(weightCol), regParam)
242245
new("GeneralizedLinearRegressionModel", jobj = jobj)
243246
})
@@ -252,6 +255,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
252255
#' This can be a character string naming a family function, a family function or
253256
#' the result of a call to a family function. Refer R family at
254257
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
258+
#' Currently these families are supported: \code{binomial}, \code{gaussian},
259+
#' \code{Gamma}, and \code{poisson}.
255260
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
256261
#' weights as 1.0.
257262
#' @param epsilon positive convergence tolerance of iterations.

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ test_that("spark.glm and predict", {
7474
data = iris, family = poisson(link = identity)), iris))
7575
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
7676

77+
# Gamma family
78+
x <- runif(100, -1, 1)
79+
y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
80+
df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
81+
model <- glm(y ~ x, family = Gamma, df)
82+
out <- capture.output(print(summary(model)))
83+
expect_true(any(grepl("Dispersion parameter for gamma family", out)))
84+
7785
# Test stats::predict is working
7886
x <- rnorm(15)
7987
y <- x + rnorm(15)

0 commit comments

Comments
 (0)