@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration
2020import scala .collection .JavaConverters ._
2121
2222import org .apache .spark .annotation .Experimental
23- import org .apache .spark .mllib .tree .impurity .Impurity
23+ import org .apache .spark .mllib .tree .impurity .{ Variance , Entropy , Gini , Impurity }
2424import org .apache .spark .mllib .tree .configuration .Algo ._
2525import org .apache .spark .mllib .tree .configuration .QuantileStrategy ._
2626
@@ -90,4 +90,33 @@ class Strategy (
9090 categoricalFeaturesInfo.asInstanceOf [java.util.Map [Int , Int ]].asScala.toMap)
9191 }
9292
93+ private [tree] def assertValid (): Unit = {
94+ algo match {
95+ case Classification =>
96+ require(numClassesForClassification >= 2 ,
97+ s " DecisionTree Strategy for Classification must have numClassesForClassification >= 2, " +
98+ s " but numClassesForClassification = $numClassesForClassification. " )
99+ require(Set (Gini , Entropy ).contains(impurity),
100+ s " DecisionTree Strategy given invalid impurity for Classification: $impurity. " +
101+ s " Valid settings: Gini, Entropy " )
102+ case Regression =>
103+ require(impurity == Variance ,
104+ s " DecisionTree Strategy given invalid impurity for Regression: $impurity. " +
105+ s " Valid settings: Variance " )
106+ case _ =>
107+ throw new IllegalArgumentException (
108+ s " DecisionTree Strategy given invalid algo parameter: $algo. " +
109+ s " Valid settings are: Classification, Regression. " )
110+ }
111+ require(maxDepth >= 0 , s " DecisionTree Strategy given invalid maxDepth parameter: $maxDepth. " +
112+ s " Valid values are integers >= 0. " )
113+ require(maxBins >= 2 , s " DecisionTree Strategy given invalid maxBins parameter: $maxBins. " +
114+ s " Valid values are integers >= 2. " )
115+ categoricalFeaturesInfo.foreach { case (feature, arity) =>
116+ require(arity >= 2 ,
117+ s " DecisionTree Strategy given invalid categoricalFeaturesInfo setting: " +
118+ s " feature $feature has $arity categories. The number of categories should be >= 2. " )
119+ }
120+ }
121+
93122}
0 commit comments