Skip to content

Commit 7a61f7b

Browse files
committed
Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters
1 parent e537b33 commit 7a61f7b

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ import org.apache.spark.util.random.XORShiftRandom
4040
@Experimental
4141
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
4242

43+
strategy.assertValid()
44+
4345
/**
4446
* Method to train a decision tree model over an RDD
4547
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
@@ -1368,10 +1370,14 @@ object DecisionTree extends Serializable with Logging {
13681370

13691371

13701372
/*
1371-
* Ensure #bins is always greater than the categories. For multiclass classification,
1372-
* #bins should be greater than 2^(maxCategories - 1) - 1.
1373+
* Ensure numBins is always greater than the categories. For multiclass classification,
1374+
* numBins should be greater than 2^(maxCategories - 1) - 1.
13731375
* It's a limitation of the current implementation but a reasonable trade-off since features
13741376
* with large number of categories get favored over continuous features.
1377+
*
1378+
* This needs to be checked here instead of in Strategy since numBins can be determined
1379+
* by the number of training examples.
1380+
* TODO: Allow this case, where we simply will know nothing about some categories.
13751381
*/
13761382
if (strategy.categoricalFeaturesInfo.size > 0) {
13771383
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration
2020
import scala.collection.JavaConverters._
2121

2222
import org.apache.spark.annotation.Experimental
23-
import org.apache.spark.mllib.tree.impurity.Impurity
23+
import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
2424
import org.apache.spark.mllib.tree.configuration.Algo._
2525
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
2626

@@ -90,4 +90,33 @@ class Strategy (
9090
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
9191
}
9292

93+
private[tree] def assertValid(): Unit = {
94+
algo match {
95+
case Classification =>
96+
require(numClassesForClassification >= 2,
97+
s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," +
98+
s" but numClassesForClassification = $numClassesForClassification.")
99+
require(Set(Gini, Entropy).contains(impurity),
100+
s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
101+
s" Valid settings: Gini, Entropy")
102+
case Regression =>
103+
require(impurity == Variance,
104+
s"DecisionTree Strategy given invalid impurity for Regression: $impurity." +
105+
s" Valid settings: Variance")
106+
case _ =>
107+
throw new IllegalArgumentException(
108+
s"DecisionTree Strategy given invalid algo parameter: $algo." +
109+
s" Valid settings are: Classification, Regression.")
110+
}
111+
require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." +
112+
s" Valid values are integers >= 0.")
113+
require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." +
114+
s" Valid values are integers >= 2.")
115+
categoricalFeaturesInfo.foreach { case (feature, arity) =>
116+
require(arity >= 2,
117+
s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
118+
s" feature $feature has $arity categories. The number of categories should be >= 2.")
119+
}
120+
}
121+
93122
}

0 commit comments

Comments
 (0)