Skip to content

Commit c6e2dfc

Browse files
committed
Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to the Python API in tree.py
1 parent 0278a11 commit c6e2dfc

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ object DecisionTreeRunner {
5555
maxDepth: Int = 5,
5656
impurity: ImpurityType = Gini,
5757
maxBins: Int = 32,
58+
minInstancesPerNode: Int = 1,
59+
minInfoGain: Double = 0.0,
5860
fracTest: Double = 0.2)
5961

6062
def main(args: Array[String]) {
@@ -75,6 +77,13 @@ object DecisionTreeRunner {
7577
opt[Int]("maxBins")
7678
.text(s"max number of bins, default: ${defaultParams.maxBins}")
7779
.action((x, c) => c.copy(maxBins = x))
80+
opt[Int]("minInstancesPerNode")
81+
.text(s"min number of instances required at child nodes to create the parent split," +
82+
s" default: ${defaultParams.minInstancesPerNode}")
83+
.action((x, c) => c.copy(minInstancesPerNode = x))
84+
opt[Double]("minInfoGain")
85+
.text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
86+
.action((x, c) => c.copy(minInfoGain = x))
7887
opt[Double]("fracTest")
7988
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
8089
.action((x, c) => c.copy(fracTest = x))
@@ -179,7 +188,9 @@ object DecisionTreeRunner {
179188
impurity = impurityCalculator,
180189
maxDepth = params.maxDepth,
181190
maxBins = params.maxBins,
182-
numClassesForClassification = numClasses)
191+
numClassesForClassification = numClasses,
192+
minInstancesPerNode = params.minInstancesPerNode,
193+
minInfoGain = params.minInfoGain)
183194
val model = DecisionTree.train(training, strategy)
184195

185196
println(model)

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
303303
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
304304
impurityStr: String,
305305
maxDepth: Int,
306-
maxBins: Int): DecisionTreeModel = {
306+
maxBins: Int,
307+
minInstancesPerNode: Int,
308+
minInfoGain: Double): DecisionTreeModel = {
307309

308310
val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
309311

@@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable {
316318
maxDepth = maxDepth,
317319
numClassesForClassification = numClasses,
318320
maxBins = maxBins,
319-
categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
321+
categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
322+
minInstancesPerNode = minInstancesPerNode,
323+
minInfoGain = minInfoGain)
320324

321325
DecisionTree.train(data, strategy)
322326
}

python/pyspark/mllib/tree.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ class DecisionTree(object):
138138

139139
@staticmethod
140140
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
141-
impurity="gini", maxDepth=5, maxBins=32):
141+
impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
142+
minInfoGain=0.0):
142143
"""
143144
Train a DecisionTreeModel for classification.
144145
@@ -154,6 +155,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
154155
E.g., depth 0 means 1 leaf node.
155156
Depth 1 means 1 internal node + 2 leaf nodes.
156157
:param maxBins: Number of bins used for finding splits at each node.
158+
:param minInstancesPerNode: Min number of instances required at child nodes to create
159+
the parent split
160+
:param minInfoGain: Min info gain required to create a split
157161
:return: DecisionTreeModel
158162
"""
159163
sc = data.context
@@ -164,13 +168,14 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
164168
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
165169
dataBytes._jrdd, "classification",
166170
numClasses, categoricalFeaturesInfoJMap,
167-
impurity, maxDepth, maxBins)
171+
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
168172
dataBytes.unpersist()
169173
return DecisionTreeModel(sc, model)
170174

171175
@staticmethod
172176
def trainRegressor(data, categoricalFeaturesInfo,
173-
impurity="variance", maxDepth=5, maxBins=32):
177+
impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
178+
minInfoGain=0.0):
174179
"""
175180
Train a DecisionTreeModel for regression.
176181
@@ -185,6 +190,9 @@ def trainRegressor(data, categoricalFeaturesInfo,
185190
E.g., depth 0 means 1 leaf node.
186191
Depth 1 means 1 internal node + 2 leaf nodes.
187192
:param maxBins: Number of bins used for finding splits at each node.
193+
:param minInstancesPerNode: Min number of instances required at child nodes to create
194+
the parent split
195+
:param minInfoGain: Min info gain required to create a split
188196
:return: DecisionTreeModel
189197
"""
190198
sc = data.context
@@ -195,7 +203,7 @@ def trainRegressor(data, categoricalFeaturesInfo,
195203
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
196204
dataBytes._jrdd, "regression",
197205
0, categoricalFeaturesInfoJMap,
198-
impurity, maxDepth, maxBins)
206+
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
199207
dataBytes.unpersist()
200208
return DecisionTreeModel(sc, model)
201209

0 commit comments

Comments
 (0)