Commit e00cac9

mgaido91 authored and srowen committed
[SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading
## What changes were proposed in this pull request?

Our `GBTClassifier` supports only the `variance` impurity. Unfortunately, its `impurity` param defaults to `gini`: it is not modifiable by the user, and it differs from the impurity actually used, which is `variance`. The issue is not limited to a wrong value being returned when the user calls `getImpurity`; it also affects loading a saved model, whose `impurityStats` are created as `gini` (since that is the value stored as the model's impurity), which leads to wrong `featureImportances` in models loaded from saved ones.

The PR changes the `impurity` param to one which allows only the value `variance`.

## How was this patch tested?

Modified UT.

Closes #22986 from mgaido91/SPARK-25959.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
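Below is a minimal reproduction sketch of the symptom described above, assuming a running `SparkSession` named `spark` (e.g. in `spark-shell`) and a hypothetical toy dataset and save path; this is not the project's test code. Before the fix, the loaded model's `featureImportances` could differ from the original's because its `impurityStats` were rebuilt as `gini`.

```scala
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors

// Hypothetical toy data; any small binary-classification DataFrame with
// "label" and "features" columns would do.
val data = spark.createDataFrame(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(0.0, Vectors.dense(0.1, 0.9)),
  LabeledPoint(1.0, Vectors.dense(0.9, 0.1))
))

val model = new GBTClassifier().setMaxIter(2).fit(data)

// Round-trip the model through disk (hypothetical path).
val path = "/tmp/spark-25959-gbt"
model.write.overwrite().save(path)
val loaded = GBTClassificationModel.load(path)

// Before the fix, the loaded model's impurityStats were rebuilt as "gini",
// so these two vectors could disagree; after the fix they match.
println(model.featureImportances)
println(loaded.featureImportances)
```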
1 parent e557c53 commit e00cac9

6 files changed: +27 -12 lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 3 additions & 1 deletion
@@ -427,7 +427,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
         s" trees based on metadata but found ${trees.length} trees.")
       val model = new GBTClassificationModel(metadata.uid,
         trees, treeWeights, numFeatures)
-      metadata.getAndSetParams(model)
+      // We ignore the impurity while loading models because in previous models it was wrongly
+      // set to gini (see SPARK-25959).
+      metadata.getAndSetParams(model, Some(List("impurity")))
       model
     }
   }
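The fix above relies on the optional skip list accepted by `metadata.getAndSetParams`: params named in the list are not re-applied from the saved metadata, so the model keeps its own (correct) default for them. The following is a hypothetical, simplified illustration of that idea only; the function name, signature, and types are invented for this sketch and are not Spark's internal `DefaultParamsReader` API.

```scala
// Hypothetical, simplified sketch of the skip-list idea; names and types here
// are illustrative only and not Spark's internal API.
def applyStoredParams(
    stored: Map[String, String],                 // param name -> saved value
    skip: Option[List[String]] = None)(
    setParam: (String, String) => Unit): Unit = {
  stored.foreach { case (name, value) =>
    // Params named in the skip list are never re-applied, so the model
    // keeps whatever default it declares for them (here: "variance").
    if (!skip.exists(_.contains(name))) setParam(name, value)
  }
}

// Usage: skipping "impurity" leaves the model's default untouched.
applyStoredParams(
  Map("maxIter" -> "20", "impurity" -> "gini"),
  Some(List("impurity"))) { (name, value) =>
  println(s"setting $name = $value")   // prints only: setting maxIter = 20
}
```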

mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
 @Since("1.4.0")
 object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
   /** Accessor for supported impurities: variance */
-  final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+  final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
 
   @Since("2.0.0")
   override def load(path: String): DecisionTreeRegressor = super.load(path)

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
 object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
   /** Accessor for supported impurity settings: variance */
   @Since("1.4.0")
-  final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+  final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
 
   /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
   @Since("1.4.0")

mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala

Lines changed: 10 additions & 9 deletions
@@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams {
 private[ml] trait DecisionTreeClassifierParams
   extends DecisionTreeParams with TreeClassifierParams
 
-/**
- * Parameters for Decision Tree-based regression algorithms.
- */
-private[ml] trait TreeRegressorParams extends Params {
-
+private[ml] trait HasVarianceImpurity extends Params {
   /**
    * Criterion used for information gain calculation (case-insensitive).
    * Supported: "variance".
@@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params {
    */
   final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
     " information gain calculation (case-insensitive). Supported options:" +
-    s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
+    s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}",
     (value: String) =>
-      TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
+      HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
 
   setDefault(impurity -> "variance")
 
@@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params {
   }
 }
 
-private[ml] object TreeRegressorParams {
+private[ml] object HasVarianceImpurity {
   // These options should be lowercase.
   final val supportedImpurities: Array[String] =
     Array("variance").map(_.toLowerCase(Locale.ROOT))
 }
 
+/**
+ * Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends HasVarianceImpurity
+
 private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
   with TreeRegressorParams with HasVarianceCol {
 
@@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams {
     Array("logistic").map(_.toLowerCase(Locale.ROOT))
 }
 
-private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity {
 
   /**
    * Loss function which GBT tries to minimize. (case-insensitive)

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

Lines changed: 1 addition & 0 deletions
@@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
         model2: GBTClassificationModel): Unit = {
       TreeTests.checkEqual(model, model2)
       assert(model.numFeatures === model2.numFeatures)
+      assert(model.featureImportances == model2.featureImportances)
     }
 
     val gbt = new GBTClassifier()

project/MimaExcludes.scala

Lines changed: 11 additions & 0 deletions
@@ -36,6 +36,17 @@ object MimaExcludes {
 
   // Exclude rules for 3.0.x
   lazy val v30excludes = v24excludes ++ Seq(
+    // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"),
+
     // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"),
