@@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
1294
1294
1295
1295
Fit a Generalized Linear Model specified by giving a symbolic description of the linear
1296
1296
predictor (link function) and a description of the error distribution (family). It supports
1297
- "gaussian", "binomial", "poisson" and "gamma " as family. Valid link functions for each family
1298
- is listed below. The first link function of each family is the default one.
1297
+ "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
1298
+ each family are listed below. The first link function of each family is the default one.
1299
1299
1300
1300
* "gaussian" -> "identity", "log", "inverse"
1301
1301
@@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
1305
1305
1306
1306
* "gamma" -> "inverse", "identity", "log"
1307
1307
1308
+ * "tweedie" -> power link function specified through "linkPower". \
1309
+ The default link power in the tweedie family is 1 - variancePower.
1310
+
1308
1311
.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
1309
1312
1310
1313
>>> from pyspark.ml.linalg import Vectors
@@ -1344,40 +1347,54 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
1344
1347
1345
1348
family = Param (Params ._dummy (), "family" , "The name of family which is a description of " +
1346
1349
"the error distribution to be used in the model. Supported options: " +
1347
- "gaussian (default), binomial, poisson and gamma ." ,
1350
+ "gaussian (default), binomial, poisson, gamma and tweedie ." ,
1348
1351
typeConverter = TypeConverters .toString )
1349
1352
link = Param (Params ._dummy (), "link" , "The name of link function which provides the " +
1350
1353
"relationship between the linear predictor and the mean of the distribution " +
1351
1354
"function. Supported options: identity, log, inverse, logit, probit, cloglog " +
1352
1355
"and sqrt." , typeConverter = TypeConverters .toString )
1353
1356
linkPredictionCol = Param (Params ._dummy (), "linkPredictionCol" , "link prediction (linear " +
1354
1357
"predictor) column name" , typeConverter = TypeConverters .toString )
1358
+ variancePower = Param (Params ._dummy (), "variancePower" , "The power in the variance function " +
1359
+ "of the Tweedie distribution which characterizes the relationship " +
1360
+ "between the variance and mean of the distribution. Only applicable " +
1361
+ "for the Tweedie family. Supported values: 0 and [1, Inf)." ,
1362
+ typeConverter = TypeConverters .toFloat )
1363
+ linkPower = Param (Params ._dummy (), "linkPower" , "The index in the power link function. " +
1364
+ "Only applicable to the Tweedie family." ,
1365
+ typeConverter = TypeConverters .toFloat )
1355
1366
1356
1367
@keyword_only
1357
1368
def __init__ (self , labelCol = "label" , featuresCol = "features" , predictionCol = "prediction" ,
1358
1369
family = "gaussian" , link = None , fitIntercept = True , maxIter = 25 , tol = 1e-6 ,
1359
- regParam = 0.0 , weightCol = None , solver = "irls" , linkPredictionCol = None ):
1370
+ regParam = 0.0 , weightCol = None , solver = "irls" , linkPredictionCol = None ,
1371
+ variancePower = 0.0 , linkPower = None ):
1360
1372
"""
1361
1373
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
1362
1374
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1363
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
1375
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1376
+ variancePower=0.0, linkPower=None)
1364
1377
"""
1365
1378
super (GeneralizedLinearRegression , self ).__init__ ()
1366
1379
self ._java_obj = self ._new_java_obj (
1367
1380
"org.apache.spark.ml.regression.GeneralizedLinearRegression" , self .uid )
1368
- self ._setDefault (family = "gaussian" , maxIter = 25 , tol = 1e-6 , regParam = 0.0 , solver = "irls" )
1381
+ self ._setDefault (family = "gaussian" , maxIter = 25 , tol = 1e-6 , regParam = 0.0 , solver = "irls" ,
1382
+ variancePower = 0.0 )
1369
1383
kwargs = self ._input_kwargs
1384
+
1370
1385
self .setParams (** kwargs )
1371
1386
1372
1387
@keyword_only
1373
1388
@since ("2.0.0" )
1374
1389
def setParams (self , labelCol = "label" , featuresCol = "features" , predictionCol = "prediction" ,
1375
1390
family = "gaussian" , link = None , fitIntercept = True , maxIter = 25 , tol = 1e-6 ,
1376
- regParam = 0.0 , weightCol = None , solver = "irls" , linkPredictionCol = None ):
1391
+ regParam = 0.0 , weightCol = None , solver = "irls" , linkPredictionCol = None ,
1392
+ variancePower = 0.0 , linkPower = None ):
1377
1393
"""
1378
1394
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
1379
1395
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1380
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
1396
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
1397
+ variancePower=0.0, linkPower=None)
1381
1398
Sets params for generalized linear regression.
1382
1399
"""
1383
1400
kwargs = self ._input_kwargs
@@ -1428,6 +1445,34 @@ def getLink(self):
1428
1445
"""
1429
1446
return self .getOrDefault (self .link )
1430
1447
1448
+ @since ("2.2.0" )
1449
+ def setVariancePower (self , value ):
1450
+ """
1451
+ Sets the value of :py:attr:`variancePower`.
1452
+ """
1453
+ return self ._set (variancePower = value )
1454
+
1455
+ @since ("2.2.0" )
1456
+ def getVariancePower (self ):
1457
+ """
1458
+ Gets the value of variancePower or its default value.
1459
+ """
1460
+ return self .getOrDefault (self .variancePower )
1461
+
1462
+ @since ("2.2.0" )
1463
+ def setLinkPower (self , value ):
1464
+ """
1465
+ Sets the value of :py:attr:`linkPower`.
1466
+ """
1467
+ return self ._set (linkPower = value )
1468
+
1469
+ @since ("2.2.0" )
1470
+ def getLinkPower (self ):
1471
+ """
1472
+ Gets the value of linkPower or its default value.
1473
+ """
1474
+ return self .getOrDefault (self .linkPower )
1475
+
1431
1476
1432
1477
class GeneralizedLinearRegressionModel (JavaModel , JavaPredictionModel , JavaMLWritable ,
1433
1478
JavaMLReadable ):
0 commit comments