
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API #2349

Closed
wants to merge 23 commits
examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

@@ -55,6 +55,8 @@ object DecisionTreeRunner {
       maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
       maxBins: Int = 32,
+      minInstancesPerNode: Int = 1,
+      minInfoGain: Double = 0.0,
       fracTest: Double = 0.2)
 
   def main(args: Array[String]) {
@@ -75,6 +77,13 @@ object DecisionTreeRunner {
       opt[Int]("maxBins")
         .text(s"max number of bins, default: ${defaultParams.maxBins}")
         .action((x, c) => c.copy(maxBins = x))
+      opt[Int]("minInstancesPerNode")
+        .text(s"min number of instances required at child nodes to create the parent split," +
+          s" default: ${defaultParams.minInstancesPerNode}")
+        .action((x, c) => c.copy(minInstancesPerNode = x))
+      opt[Double]("minInfoGain")
+        .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+        .action((x, c) => c.copy(minInfoGain = x))
       opt[Double]("fracTest")
         .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
         .action((x, c) => c.copy(fracTest = x))
@@ -179,7 +188,9 @@ object DecisionTreeRunner {
         impurity = impurityCalculator,
         maxDepth = params.maxDepth,
         maxBins = params.maxBins,
-        numClassesForClassification = numClasses)
+        numClassesForClassification = numClasses,
+        minInstancesPerNode = params.minInstancesPerNode,
+        minInfoGain = params.minInfoGain)
       val model = DecisionTree.train(training, strategy)
 
       println(model)
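Both new options are pre-pruning criteria. As a rough sketch of the rule they describe (illustrative Python, not the code Spark actually runs), a candidate split is created only when each child node receives enough instances and the split's information gain clears the threshold:

    def split_is_valid(left_count, right_count, gain,
                       min_instances_per_node=1, min_info_gain=0.0):
        # Reject splits whose children would be undersized or whose
        # information gain falls below the configured minimum.
        return (left_count >= min_instances_per_node and
                right_count >= min_instances_per_node and
                gain >= min_info_gain)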
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

@@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
       categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
       impurityStr: String,
       maxDepth: Int,
-      maxBins: Int): DecisionTreeModel = {
+      maxBins: Int,
+      minInstancesPerNode: Int,
+      minInfoGain: Double): DecisionTreeModel = {
 
     val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
 
@@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable {
       maxDepth = maxDepth,
       numClassesForClassification = numClasses,
       maxBins = maxBins,
-      categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
+      categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+      minInstancesPerNode = minInstancesPerNode,
+      minInfoGain = minInfoGain)
 
     DecisionTree.train(data, strategy)
   }
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

@@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging {
     var groupIndex = 0
     var doneTraining = true
     while (groupIndex < numGroups) {
-      val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
+      val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
         topNode, splits, bins, timer, numGroups, groupIndex)
       doneTraining = doneTraining && doneTrainingGroup
       groupIndex += 1
@@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging {
         }
       }.maxBy(_._2.gain)
 
-    require(predict.isDefined, "must calculate predict for each node")
+    assert(predict.isDefined, "must calculate predict for each node")
 
     (bestSplit, bestSplitStats, predict.get)
   }
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

@@ -77,6 +77,8 @@ class Strategy (
   }
   require(minInstancesPerNode >= 1,
     s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+  require(maxMemoryInMB <= 10240,
+    s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
 
   val isMulticlassClassification =
     algo == Classification && numClassesForClassification > 2
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala

@@ -17,18 +17,14 @@

 package org.apache.spark.mllib.tree.model
 
-import org.apache.spark.annotation.DeveloperApi
-
 /**
- * :: DeveloperApi ::
  * Predicted value for a node
  * @param predict predicted value
  * @param prob probability of the label (classification only)
  */
-@DeveloperApi
-class Predict(
+private[tree] class Predict(
     val predict: Double,
-    val prob: Double = 0.0) extends Serializable{
+    val prob: Double = 0.0) extends Serializable {
 
   override def toString = {
     "predict = %f, prob = %f".format(predict, prob)
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

@@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(gain == InformationGainStats.invalidInformationGainStats)
   }
 
-  test("don't choose split that doesn't satisfy min instance per node requirements") {
-    // if a split doesn't satisfy min instances per node requirements,
+  test("do not choose split that does not satisfy min instance per node requirements") {
+    // if a split does not satisfy min instances per node requirements,
Review thread on this line:

Contributor: Why is "don't" a typo?

Member (author): Not really a typo. But I figured that, if people are munging logs from tests, quote characters might be troublesome to deal with.

Contributor: It sounds reasonable, thanks.
     // this split is invalid, even though the information gain of split is large.
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
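To see how a high-gain split can still be rejected, here is a minimal numeric sketch (illustrative class counts, not the suite's exact fixture): with parent class counts [1, 3], a split that isolates the lone minority instance yields a large Gini gain, yet it is invalid once minInstancesPerNode = 2.

    def gini(counts):
        total = float(sum(counts))
        return 1.0 - sum((c / total) ** 2 for c in counts)

    parent = [1, 3]               # one instance of class 0, three of class 1
    left, right = [1, 0], [0, 3]  # candidate split isolates the lone class-0 point

    weighted = (sum(left) / 4.0) * gini(left) + (sum(right) / 4.0) * gini(right)
    gain = gini(parent) - weighted
    print(gain)  # 0.375: a large gain, but the left child holds only one
                 # instance, so the split fails minInstancesPerNode = 2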
python/pyspark/mllib/tree.py (16 changes: 12 additions & 4 deletions)

@@ -138,7 +138,8 @@ class DecisionTree(object):

     @staticmethod
     def trainClassifier(data, numClasses, categoricalFeaturesInfo,
-                        impurity="gini", maxDepth=5, maxBins=32):
+                        impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+                        minInfoGain=0.0):
         """
         Train a DecisionTreeModel for classification.
 
@@ -154,6 +155,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
                          E.g., depth 0 means 1 leaf node.
                          Depth 1 means 1 internal node + 2 leaf nodes.
         :param maxBins: Number of bins used for finding splits at each node.
+        :param minInstancesPerNode: Min number of instances required at child nodes to create
+            the parent split
+        :param minInfoGain: Min info gain required to create a split
         :return: DecisionTreeModel
         """
         sc = data.context
@@ -164,13 +168,14 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
         model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
             dataBytes._jrdd, "classification",
             numClasses, categoricalFeaturesInfoJMap,
-            impurity, maxDepth, maxBins)
+            impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
         dataBytes.unpersist()
         return DecisionTreeModel(sc, model)
 
     @staticmethod
     def trainRegressor(data, categoricalFeaturesInfo,
-                       impurity="variance", maxDepth=5, maxBins=32):
+                       impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
+                       minInfoGain=0.0):
         """
         Train a DecisionTreeModel for regression.
 
@@ -185,6 +190,9 @@ def trainRegressor(data, categoricalFeaturesInfo,
                          E.g., depth 0 means 1 leaf node.
                          Depth 1 means 1 internal node + 2 leaf nodes.
         :param maxBins: Number of bins used for finding splits at each node.
+        :param minInstancesPerNode: Min number of instances required at child nodes to create
+            the parent split
+        :param minInfoGain: Min info gain required to create a split
         :return: DecisionTreeModel
         """
         sc = data.context
@@ -195,7 +203,7 @@ def trainRegressor(data, categoricalFeaturesInfo,
         model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
             dataBytes._jrdd, "regression",
             0, categoricalFeaturesInfoJMap,
-            impurity, maxDepth, maxBins)
+            impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
         dataBytes.unpersist()
         return DecisionTreeModel(sc, model)

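For reference, a minimal end-to-end sketch of calling the extended Python API (the dataset, threshold values, and app name here are illustrative, not from the PR):

    from pyspark import SparkContext
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree

    sc = SparkContext(appName="DecisionTreePrePruning")  # hypothetical app name

    # Toy dataset: the label is 1.0 exactly when the first feature is positive.
    data = sc.parallelize([
        LabeledPoint(0.0, [-1.0, 0.0]),
        LabeledPoint(0.0, [-2.0, 1.0]),
        LabeledPoint(1.0, [1.0, 0.0]),
        LabeledPoint(1.0, [2.0, 1.0]),
    ])

    # Require at least 2 instances in each child and an information gain of
    # at least 0.01; candidate splits failing either check are not created.
    model = DecisionTree.trainClassifier(
        data, numClasses=2, categoricalFeaturesInfo={},
        impurity="gini", maxDepth=5, maxBins=32,
        minInstancesPerNode=2, minInfoGain=0.01)
    print(model)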