-
Notifications
You must be signed in to change notification settings - Fork 28.6k
[SPARK-19806][ML][PySpark] PySpark GeneralizedLinearRegression supports tweedie distribution. #17146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-19806][ML][PySpark] PySpark GeneralizedLinearRegression supports tweedie distribution. #17146
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha | |
|
||
Fit a Generalized Linear Model specified by giving a symbolic description of the linear | ||
predictor (link function) and a description of the error distribution (family). It supports | ||
"gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family | ||
is listed below. The first link function of each family is the default one. | ||
"gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for | ||
each family is listed below. The first link function of each family is the default one. | ||
|
||
* "gaussian" -> "identity", "log", "inverse" | ||
|
||
|
@@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha | |
|
||
* "gamma" -> "inverse", "identity", "log" | ||
|
||
* "tweedie" -> power link function specified through "linkPower". \ | ||
The default link power in the tweedie family is 1 - variancePower. | ||
|
||
.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_ | ||
|
||
>>> from pyspark.ml.linalg import Vectors | ||
|
@@ -1344,40 +1347,54 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha | |
|
||
family = Param(Params._dummy(), "family", "The name of family which is a description of " + | ||
"the error distribution to be used in the model. Supported options: " + | ||
"gaussian (default), binomial, poisson and gamma.", | ||
"gaussian (default), binomial, poisson, gamma and tweedie.", | ||
typeConverter=TypeConverters.toString) | ||
link = Param(Params._dummy(), "link", "The name of link function which provides the " + | ||
"relationship between the linear predictor and the mean of the distribution " + | ||
"function. Supported options: identity, log, inverse, logit, probit, cloglog " + | ||
"and sqrt.", typeConverter=TypeConverters.toString) | ||
linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " + | ||
"predictor) column name", typeConverter=TypeConverters.toString) | ||
variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " + | ||
"of the Tweedie distribution which characterizes the relationship " + | ||
"between the variance and mean of the distribution. Only applicable " + | ||
"for the Tweedie family. Supported values: 0 and [1, Inf).", | ||
typeConverter=TypeConverters.toFloat) | ||
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " + | ||
"Only applicable to the Tweedie family.", | ||
typeConverter=TypeConverters.toFloat) | ||
|
||
@keyword_only | ||
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", | ||
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a check to make sure link=None when family="Tweedie"? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, actually we allow users to set There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, it sounds like There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, there is no default value for |
||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): | ||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, | ||
variancePower=0.0, linkPower=None): | ||
""" | ||
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ | ||
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ | ||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) | ||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ | ||
variancePower=0.0, linkPower=None) | ||
""" | ||
super(GeneralizedLinearRegression, self).__init__() | ||
self._java_obj = self._new_java_obj( | ||
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) | ||
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") | ||
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls", | ||
variancePower=0.0) | ||
kwargs = self._input_kwargs | ||
|
||
self.setParams(**kwargs) | ||
|
||
@keyword_only | ||
@since("2.0.0") | ||
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", | ||
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, | ||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): | ||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, | ||
variancePower=0.0, linkPower=None): | ||
""" | ||
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ | ||
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ | ||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) | ||
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ | ||
variancePower=0.0, linkPower=None) | ||
Sets params for generalized linear regression. | ||
""" | ||
kwargs = self._input_kwargs | ||
|
@@ -1428,6 +1445,34 @@ def getLink(self): | |
""" | ||
return self.getOrDefault(self.link) | ||
|
||
@since("2.2.0") | ||
def setVariancePower(self, value): | ||
""" | ||
Sets the value of :py:attr:`variancePower`. | ||
""" | ||
return self._set(variancePower=value) | ||
|
||
@since("2.2.0") | ||
def getVariancePower(self): | ||
""" | ||
Gets the value of variancePower or its default value. | ||
""" | ||
return self.getOrDefault(self.variancePower) | ||
|
||
@since("2.2.0") | ||
def setLinkPower(self, value): | ||
""" | ||
Sets the value of :py:attr:`linkPower`. | ||
""" | ||
return self._set(linkPower=value) | ||
|
||
@since("2.2.0") | ||
def getLinkPower(self): | ||
""" | ||
Gets the value of linkPower or its default value. | ||
""" | ||
return self.getOrDefault(self.linkPower) | ||
|
||
|
||
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, | ||
JavaMLReadable): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1223,6 +1223,26 @@ def test_apply_binary_term_freqs(self): | |
": expected " + str(expected[i]) + ", got " + str(features[i])) | ||
|
||
|
||
class GeneralizedLinearRegressionTest(SparkSessionTestCase): | ||
|
||
def test_tweedie_distribution(self): | ||
|
||
df = self.spark.createDataFrame( | ||
[(1.0, Vectors.dense(0.0, 0.0)), | ||
(1.0, Vectors.dense(1.0, 2.0)), | ||
(2.0, Vectors.dense(0.0, 0.0)), | ||
(2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) | ||
|
||
glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) | ||
model = glr.fit(df) | ||
self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm curious: where did the expected values come from? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. They are produced by R with the same input. |
||
self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) | ||
|
||
model2 = glr.setLinkPower(-1.0).fit(df) | ||
self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) | ||
self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) | ||
|
||
|
||
class ALSTest(SparkSessionTestCase): | ||
|
||
def test_storage_levels(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens when both variancePower and linkPower are set?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will produce a model according to the specified `variancePower` and `linkPower`. The doc here is to explain the value of `linkPower` if users don't specify it.