
[SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading #22986


Closed · wants to merge 6 commits
@@ -427,7 +427,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
s" trees based on metadata but found ${trees.length} trees.")
val model = new GBTClassificationModel(metadata.uid,
trees, treeWeights, numFeatures)
- metadata.getAndSetParams(model)
+ // We ignore the impurity while loading models because in previous models it was wrongly
+ // set to gini (see SPARK-25959).
+ metadata.getAndSetParams(model, Some(List("impurity")))
model
}
}
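For context, here is a minimal sketch of the save/load round trip this hunk affects. The dataset and paths are hypothetical and an active SparkSession named `spark` is assumed; it only illustrates the symptom the new test below also checks: before this fix, a persisted GBTClassificationModel could read its impurity back as gini and derive featureImportances from the wrong impurity stats.

```scala
// Hedged sketch (hypothetical data path and model path): persist and reload a
// GBTClassificationModel, then compare the reloaded model with the original.
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}

val training = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")

val model = new GBTClassifier()
  .setMaxIter(5)
  .fit(training)

model.write.overwrite().save("/tmp/spark-25959-gbt")
val loaded = GBTClassificationModel.load("/tmp/spark-25959-gbt")

// Before SPARK-25959 the loaded model could report "gini" here, even though the
// trees were built with variance-based impurity stats.
println(loaded.explainParam(loaded.impurity))

// featureImportances is computed from the impurity stats, so it should be
// identical after the round trip (it was not on affected versions).
assert(model.featureImportances == loaded.featureImportances)
```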
@@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("1.4.0")
object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
/** Accessor for supported impurities: variance */
- final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+ final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities

@Since("2.0.0")
override def load(path: String): DecisionTreeRegressor = super.load(path)
@@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
- final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+ final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities

/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
@Since("1.4.0")
19 changes: 10 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams {
private[ml] trait DecisionTreeClassifierParams
extends DecisionTreeParams with TreeClassifierParams

- /**
- * Parameters for Decision Tree-based regression algorithms.
- */
- private[ml] trait TreeRegressorParams extends Params {
Member: You can actually keep this trait and have it extend HasVarianceImpurity; it would be a no-op like a few other traits here. Although 'redundant', would it avoid any MiMa warnings or very slightly improve compatibility, even if this is private[spark]?

Contributor (author): Actually, this trait is still there (just a few lines below). The MiMa warnings are there because the return type of setImpurity is this.type, which is now HasVarianceImpurity, and the same for the setter...

Member: Just looking at this one more time... hm, isn't this.type always referring to the current class's type? Maybe it's some detail of the generated code and this is essentially a false positive. Or, honestly, setImpurity can just be removed in this change too; it's deprecated for all of these classes. I don't mind doing that later either; it would just avoid these MiMa exclusions.

Contributor (author): Yes, it is, AFAIK. I think it is a false positive.

I wouldn't remove it in this change: since this is a bug fix, this PR can be backported to 2.4.1 too, so I'd do that in a separate change. We would need to add the MiMa rules for the missing methods anyway, so it doesn't change much on the exclusions either...

Member: That's true, but I don't know whether we can backport it because of the binary incompatibility, internal as it may be. If it's not an issue, then yes, it can be backported.

Contributor (author): I see. I am not sure how to verify that.
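To make the point about this.type concrete, here is a standalone sketch with made-up trait names (not the actual Spark traits): a method returning this.type still yields the concrete subclass type at every call site after it is moved to a parent trait, which is why the MiMa warnings discussed here look like false positives at the source level; MiMa presumably flags them because the erased return type in the bytecode changes with the defining trait.

```scala
// Hedged sketch, hypothetical names mirroring the refactoring discussed here.
trait HasVarianceImpurityLike {
  // this.type means the declared result type is the receiver's own type,
  // no matter which trait the method is defined in.
  def setImpurity(value: String): this.type = {
    // ...store the value somewhere...
    this
  }
}

// The old trait becomes a no-op alias of the new parent, as in the diff below.
trait TreeRegressorParamsLike extends HasVarianceImpurityLike

class DecisionTreeRegressorLike extends TreeRegressorParamsLike

object ThisTypeDemo extends App {
  // Call sites still get the concrete type back, so source compatibility is preserved.
  val r: DecisionTreeRegressorLike = new DecisionTreeRegressorLike().setImpurity("variance")
  println(r.getClass.getSimpleName) // DecisionTreeRegressorLike
}
```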


+ private[ml] trait HasVarianceImpurity extends Params {
/**
* Criterion used for information gain calculation (case-insensitive).
* Supported: "variance".
@@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params {
*/
final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}",
(value: String) =>
- TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
+ HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))

setDefault(impurity -> "variance")

@@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params {
}
}

- private[ml] object TreeRegressorParams {
+ private[ml] object HasVarianceImpurity {
// These options should be lowercase.
final val supportedImpurities: Array[String] =
Array("variance").map(_.toLowerCase(Locale.ROOT))
}

+ /**
+ * Parameters for Decision Tree-based regression algorithms.
+ */
+ private[ml] trait TreeRegressorParams extends HasVarianceImpurity

private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
with TreeRegressorParams with HasVarianceCol {

@@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams {
Array("logistic").map(_.toLowerCase(Locale.ROOT))
}

- private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+ private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity {

/**
* Loss function which GBT tries to minimize. (case-insensitive)
Member: Another issue is that GBTClassifier also logs the wrong value of the impurity param.

Contributor (author): Yes, that's right. Currently we just say that we are using gini, while we are actually using variance...
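A quick way to observe what is reported (a sketch using only the public param API, nothing else assumed): printing the impurity param of a plain GBTClassifier shows the value that instrumentation would log during fit.

```scala
// Hedged sketch: inspect the impurity param a GBT classifier reports.
// On affected versions this showed "gini" even though variance stats were used internally.
import org.apache.spark.ml.classification.GBTClassifier

val gbt = new GBTClassifier()
println(gbt.explainParam(gbt.impurity)) // declared options plus the current/default value
println(gbt.getImpurity)                // the value that ends up in the training logs
```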

@@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
model2: GBTClassificationModel): Unit = {
TreeTests.checkEqual(model, model2)
assert(model.numFeatures === model2.numFeatures)
+ assert(model.featureImportances == model2.featureImportances)
}

val gbt = new GBTClassifier()
11 changes: 11 additions & 0 deletions project/MimaExcludes.scala
@@ -36,6 +36,17 @@ object MimaExcludes {

// Exclude rules for 3.0.x
lazy val v30excludes = v24excludes ++ Seq(
+ // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"),

// [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"),