Skip to content

Commit 8daf10e

Browse files
committed
[SPARK-19155][ML] MLlib GeneralizedLinearRegression family and link should be case insensitive
## What changes were proposed in this pull request?

MLlib `GeneralizedLinearRegression` `family` and `link` should be case insensitive. This is consistent with some other MLlib params such as [`featureSubsetStrategy`](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala#L415).

## How was this patch tested?

Updated the corresponding tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #16516 from yanboliang/spark-19133.

(cherry picked from commit 3dcad9f)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
1 parent 6f0ad57 commit 8daf10e

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
5757
final val family: Param[String] = new Param(this, "family",
5858
"The name of family which is a description of the error distribution to be used in the " +
5959
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
60-
ParamValidators.inArray[String](supportedFamilyNames.toArray))
60+
(value: String) => supportedFamilyNames.contains(value.toLowerCase))
6161

6262
/** @group getParam */
6363
@Since("2.0.0")
@@ -74,7 +74,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
7474
final val link: Param[String] = new Param(this, "link", "The name of link function " +
7575
"which provides the relationship between the linear predictor and the mean of the " +
7676
s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}",
77-
ParamValidators.inArray[String](supportedLinkNames.toArray))
77+
(value: String) => supportedLinkNames.contains(value.toLowerCase))
7878

7979
/** @group getParam */
8080
@Since("2.0.0")
@@ -405,7 +405,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
405405
* @param name family name: "gaussian", "binomial", "poisson" or "gamma".
406406
*/
407407
def fromName(name: String): Family = {
408-
name match {
408+
name.toLowerCase match {
409409
case Gaussian.name => Gaussian
410410
case Binomial.name => Binomial
411411
case Poisson.name => Poisson
@@ -609,7 +609,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
609609
* "inverse", "probit", "cloglog" or "sqrt".
610610
*/
611611
def fromName(name: String): Link = {
612-
name match {
612+
name.toLowerCase match {
613613
case Identity.name => Identity
614614
case Logit.name => Logit
615615
case Log.name => Log

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ class GeneralizedLinearRegressionSuite
553553
for ((link, dataset) <- Seq(("inverse", datasetGammaInverse),
554554
("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
555555
for (fitIntercept <- Seq(false, true)) {
556-
val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
556+
val trainer = new GeneralizedLinearRegression().setFamily("Gamma").setLink(link)
557557
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
558558
val model = trainer.fit(dataset)
559559
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
@@ -989,7 +989,7 @@ class GeneralizedLinearRegressionSuite
989989
-0.6344390 0.3172195 0.2114797 -0.1586097
990990
*/
991991
val trainer = new GeneralizedLinearRegression()
992-
.setFamily("gamma")
992+
.setFamily("Gamma")
993993
.setWeightCol("weight")
994994

995995
val model = trainer.fit(datasetWithWeight)

0 commit comments

Comments
 (0)