Commit c9de184

actuaryzhang authored and cmonkey committed
[SPARK-19155][ML] Make family case insensitive in GLM
## What changes were proposed in this pull request?

This is a supplement to PR apache#16516, which did not make the value from `getFamily` case insensitive. Current tests of Poisson/binomial GLM with weight fail when specifying 'Poisson' or 'Binomial', because the calculation of `dispersion` and `pValue` checks the value of family retrieved from `getFamily`:

```
model.getFamily == Binomial.name || model.getFamily == Poisson.name
```

## How was this patch tested?

Update existing tests for 'Poisson' and 'Binomial'.

yanboliang felixcheung imatiach-msft

Author: actuaryzhang <actuaryzhang10@gmail.com>

Closes apache#16675 from actuaryzhang/family.
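The failure described above can be reproduced without Spark. The following is a minimal sketch, assuming that `Binomial.name` and `Poisson.name` are the lower-case strings `"binomial"` and `"poisson"` (the diff implies this but does not show their definitions). The object and method names are hypothetical and exist only for illustration.

```scala
// Minimal sketch without Spark. The two name constants are assumptions that
// stand in for Binomial.name and Poisson.name, which this page does not show.
object FamilyCheckSketch {
  val binomialName = "binomial" // assumed value of Binomial.name
  val poissonName = "poisson"   // assumed value of Poisson.name

  // Comparison as written before this commit: exact, case-sensitive match.
  def isFixedDispersionBefore(family: String): Boolean =
    family == binomialName || family == poissonName

  // Comparison as written after this commit: lower-case the family first.
  def isFixedDispersionAfter(family: String): Boolean =
    family.toLowerCase == binomialName || family.toLowerCase == poissonName

  def main(args: Array[String]): Unit = {
    // Print both results for a few spellings of the family name.
    for (family <- Seq("binomial", "Binomial", "poisson", "Poisson", "gaussian")) {
      val before = isFixedDispersionBefore(family)
      val after = isFixedDispersionAfter(family)
      println(s"$family: before=$before, after=$after")
    }
  }
}
```

With the old comparison, `"Binomial"` and `"Poisson"` fall through to the branch meant for families with an estimated dispersion, which is why the weighted tests produced different `dispersion` and p-values.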
1 parent 104ea17 commit c9de184

2 files changed: +6 -4 lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 4 additions & 2 deletions
@@ -1044,7 +1044,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
    */
   @Since("2.0.0")
   lazy val dispersion: Double = if (
-    model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
+    model.getFamily.toLowerCase == Binomial.name ||
+      model.getFamily.toLowerCase == Poisson.name) {
     1.0
   } else {
     val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0)
@@ -1147,7 +1148,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] (
   @Since("2.0.0")
   lazy val pValues: Array[Double] = {
     if (isNormalSolver) {
-      if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) {
+      if (model.getFamily.toLowerCase == Binomial.name ||
+        model.getFamily.toLowerCase == Poisson.name) {
         tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) }
       } else {
         tValues.map { x =>
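One caveat about the approach in this diff: the no-argument `toLowerCase` follows the JVM default locale. Under a Turkish locale, for example, an upper-case `I` lower-cases to a dotless `ı`, so an all-caps `"BINOMIAL"` would still fail to match. The sketch below shows a locale-independent variant that also normalizes the family once instead of once per comparison. The helper names and the literal family names are assumptions for illustration and are not part of this commit.

```scala
import java.util.Locale

// Hypothetical helper, not part of this commit.
object FamilyNormalizationSketch {
  // Lower-case with a fixed locale so the result does not depend on the
  // JVM default locale.
  def normalizeFamily(family: String): String =
    family.toLowerCase(Locale.ROOT)

  // True for the two families that both hunks above test for.
  // The literals stand in for Binomial.name and Poisson.name (assumed values).
  def isBinomialOrPoisson(family: String): Boolean = {
    val normalized = normalizeFamily(family)
    normalized == "binomial" || normalized == "poisson"
  }
}
```

Both the `dispersion` and `pValues` checks could call such a helper, which would keep the two conditions from drifting apart.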

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -758,7 +758,7 @@ class GeneralizedLinearRegressionSuite
        0.028480 0.069123 0.935495 -0.049613
      */
     val trainer = new GeneralizedLinearRegression()
-      .setFamily("binomial")
+      .setFamily("Binomial")
       .setWeightCol("weight")
       .setFitIntercept(false)
 
@@ -875,7 +875,7 @@ class GeneralizedLinearRegressionSuite
        -0.4378554 0.2189277 0.1459518 -0.1094638
      */
     val trainer = new GeneralizedLinearRegression()
-      .setFamily("poisson")
+      .setFamily("Poisson")
       .setWeightCol("weight")
       .setFitIntercept(true)
 
