Skip to content

Commit 81303f7

Browse files
committed
[SPARK-19806][ML][PYSPARK] PySpark GeneralizedLinearRegression supports tweedie distribution.
## What changes were proposed in this pull request? PySpark ```GeneralizedLinearRegression``` supports tweedie distribution. ## How was this patch tested? Add unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #17146 from yanboliang/spark-19806.
1 parent 1fa5886 commit 81303f7

File tree

3 files changed

+77
-12
lines changed

3 files changed

+77
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
6666
/**
6767
* Param for the power in the variance function of the Tweedie distribution which provides
6868
* the relationship between the variance and mean of the distribution.
69-
* Only applicable for the Tweedie family.
69+
* Only applicable to the Tweedie family.
7070
* (see <a href="https://en.wikipedia.org/wiki/Tweedie_distribution">
7171
* Tweedie Distribution (Wikipedia)</a>)
7272
* Supported values: 0 and [1, Inf).
@@ -79,7 +79,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
7979
final val variancePower: DoubleParam = new DoubleParam(this, "variancePower",
8080
"The power in the variance function of the Tweedie distribution which characterizes " +
8181
"the relationship between the variance and mean of the distribution. " +
82-
"Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).",
82+
"Only applicable to the Tweedie family. Supported values: 0 and [1, Inf).",
8383
(x: Double) => x >= 1.0 || x == 0.0)
8484

8585
/** @group getParam */
@@ -106,7 +106,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
106106
def getLink: String = $(link)
107107

108108
/**
109-
* Param for the index in the power link function. Only applicable for the Tweedie family.
109+
* Param for the index in the power link function. Only applicable to the Tweedie family.
110110
* Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt
111111
* link, respectively.
112112
* When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod"
@@ -116,7 +116,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
116116
*/
117117
@Since("2.2.0")
118118
final val linkPower: DoubleParam = new DoubleParam(this, "linkPower",
119-
"The index in the power link function. Only applicable for the Tweedie family.")
119+
"The index in the power link function. Only applicable to the Tweedie family.")
120120

121121
/** @group getParam */
122122
@Since("2.2.0")

python/pyspark/ml/regression.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12941294
12951295
Fit a Generalized Linear Model specified by giving a symbolic description of the linear
12961296
predictor (link function) and a description of the error distribution (family). It supports
1297-
"gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
1298-
is listed below. The first link function of each family is the default one.
1297+
"gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
1298+
each family is listed below. The first link function of each family is the default one.
12991299
13001300
* "gaussian" -> "identity", "log", "inverse"
13011301
@@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
13051305
13061306
* "gamma" -> "inverse", "identity", "log"
13071307
1308+
* "tweedie" -> power link function specified through "linkPower". \
1309+
The default link power in the tweedie family is 1 - variancePower.
1310+
13081311
.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
13091312
13101313
>>> from pyspark.ml.linalg import Vectors
@@ -1344,40 +1347,54 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
13441347

13451348
family = Param(Params._dummy(), "family", "The name of family which is a description of " +
13461349
"the error distribution to be used in the model. Supported options: " +
1347-
"gaussian (default), binomial, poisson and gamma.",
1350+
"gaussian (default), binomial, poisson, gamma and tweedie.",
13481351
typeConverter=TypeConverters.toString)
13491352
link = Param(Params._dummy(), "link", "The name of link function which provides the " +
13501353
"relationship between the linear predictor and the mean of the distribution " +
13511354
"function. Supported options: identity, log, inverse, logit, probit, cloglog " +
13521355
"and sqrt.", typeConverter=TypeConverters.toString)
13531356
linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
13541357
"predictor) column name", typeConverter=TypeConverters.toString)
1358+
variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " +
1359+
"of the Tweedie distribution which characterizes the relationship " +
1360+
"between the variance and mean of the distribution. Only applicable " +
1361+
"for the Tweedie family. Supported values: 0 and [1, Inf).",
1362+
typeConverter=TypeConverters.toFloat)
1363+
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
1364+
"Only applicable to the Tweedie family.",
1365+
typeConverter=TypeConverters.toFloat)
13551366

13561367
@keyword_only
13571368
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
13581369
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
1359-
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None):
1370+
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
1371+
variancePower=0.0, linkPower=None):
13601372
"""
13611373
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
13621374
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1363-
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
1375+
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1376+
variancePower=0.0, linkPower=None)
13641377
"""
13651378
super(GeneralizedLinearRegression, self).__init__()
13661379
self._java_obj = self._new_java_obj(
13671380
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
1368-
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
1381+
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
1382+
variancePower=0.0)
13691383
kwargs = self._input_kwargs
1384+
13701385
self.setParams(**kwargs)
13711386

13721387
@keyword_only
13731388
@since("2.0.0")
13741389
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
13751390
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
1376-
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None):
1391+
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
1392+
variancePower=0.0, linkPower=None):
13771393
"""
13781394
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
13791395
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1380-
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
1396+
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1397+
variancePower=0.0, linkPower=None)
13811398
Sets params for generalized linear regression.
13821399
"""
13831400
kwargs = self._input_kwargs
@@ -1428,6 +1445,34 @@ def getLink(self):
14281445
"""
14291446
return self.getOrDefault(self.link)
14301447

1448+
@since("2.2.0")
1449+
def setVariancePower(self, value):
1450+
"""
1451+
Sets the value of :py:attr:`variancePower`.
1452+
"""
1453+
return self._set(variancePower=value)
1454+
1455+
@since("2.2.0")
1456+
def getVariancePower(self):
1457+
"""
1458+
Gets the value of variancePower or its default value.
1459+
"""
1460+
return self.getOrDefault(self.variancePower)
1461+
1462+
@since("2.2.0")
1463+
def setLinkPower(self, value):
1464+
"""
1465+
Sets the value of :py:attr:`linkPower`.
1466+
"""
1467+
return self._set(linkPower=value)
1468+
1469+
@since("2.2.0")
1470+
def getLinkPower(self):
1471+
"""
1472+
Gets the value of linkPower or its default value.
1473+
"""
1474+
return self.getOrDefault(self.linkPower)
1475+
14311476

14321477
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
14331478
JavaMLReadable):

python/pyspark/ml/tests.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,26 @@ def test_apply_binary_term_freqs(self):
12231223
": expected " + str(expected[i]) + ", got " + str(features[i]))
12241224

12251225

1226+
class GeneralizedLinearRegressionTest(SparkSessionTestCase):
1227+
1228+
def test_tweedie_distribution(self):
1229+
1230+
df = self.spark.createDataFrame(
1231+
[(1.0, Vectors.dense(0.0, 0.0)),
1232+
(1.0, Vectors.dense(1.0, 2.0)),
1233+
(2.0, Vectors.dense(0.0, 0.0)),
1234+
(2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"])
1235+
1236+
glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6)
1237+
model = glr.fit(df)
1238+
self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4))
1239+
self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4))
1240+
1241+
model2 = glr.setLinkPower(-1.0).fit(df)
1242+
self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
1243+
self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
1244+
1245+
12261246
class ALSTest(SparkSessionTestCase):
12271247

12281248
def test_storage_levels(self):

0 commit comments

Comments
 (0)