Skip to content

Commit

Permalink
[SPARK-3751] [mllib] DecisionTree: example update + print options
Browse files Browse the repository at this point in the history
DecisionTreeRunner functionality additions:
* Allow user to pass in a test dataset
* Do not print full model if the model is too large.

As part of this, modify DecisionTreeModel and RandomForestModel to allow printing less info.  Proposed updates:
* toString: prints model summary
* toDebugString: prints full model (named after RDD.toDebugString)

Similar update to Python API:
* __repr__() now prints a model summary
* toDebugString() now prints the full model

CC: mengxr, chouqin, manishamde, codedeft. Small update (whoever can take a look). Thanks!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes apache#2604 from jkbradley/dtrunner-update and squashes the following commits:

b2b3c60 [Joseph K. Bradley] re-added python sql doc test, temporarily removed before
07b1fae [Joseph K. Bradley] repr() now prints a model summary toDebugString() now prints the full model
1d0d93d [Joseph K. Bradley] Updated DT and RF to print less when toString is called. Added toDebugString for verbose printing.
22eac8c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
e007a95 [Joseph K. Bradley] Updated DecisionTreeRunner to accept a test dataset.
  • Loading branch information
jkbradley authored and mengxr committed Oct 1, 2014
1 parent eb43043 commit 7bf6cc9
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ object DecisionTreeRunner {

case class Params(
input: String = null,
testInput: String = "",
dataFormat: String = "libsvm",
algo: Algo = Classification,
maxDepth: Int = 5,
Expand Down Expand Up @@ -98,13 +99,18 @@ object DecisionTreeRunner {
s"default: ${defaultParams.featureSubsetStrategy}")
.action((x, c) => c.copy(featureSubsetStrategy = x))
opt[Double]("fracTest")
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
.text(s"fraction of data to hold out for testing. If given option testInput, " +
s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
opt[String]("testInput")
.text(s"input path to test dataset. If given, option fracTest is ignored." +
s" default: ${defaultParams.testInput}")
.action((x, c) => c.copy(testInput = x))
opt[String]("<dataFormat>")
.text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
.action((x, c) => c.copy(dataFormat = x))
arg[String]("<input>")
.text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
.text("input path to labeled examples")
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
Expand Down Expand Up @@ -141,7 +147,7 @@ object DecisionTreeRunner {
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
}
// For classification, re-index classes if needed.
val (examples, numClasses) = params.algo match {
val (examples, classIndexMap, numClasses) = params.algo match {
case Classification => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
Expand Down Expand Up @@ -170,16 +176,40 @@ object DecisionTreeRunner {
val frac = classCounts(c) / numExamples.toDouble
println(s"$c\t$frac\t${classCounts(c)}")
}
(examples, numClasses)
(examples, classIndexMap, numClasses)
}
case Regression =>
(origExamples, 0)
(origExamples, null, 0)
case _ =>
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}

// Split into training, test.
val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
// Create training, test sets.
val splits = if (params.testInput != "") {
// Load testInput.
val origTestExamples = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput)
}
params.algo match {
case Classification => {
// classCounts: class --> # examples in class
val testExamples = {
if (classIndexMap.isEmpty) {
origTestExamples
} else {
origTestExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features))
}
}
Array(examples, testExamples)
}
case Regression =>
Array(examples, origTestExamples)
}
} else {
// Split input into training, test.
examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
}
val training = splits(0).cache()
val test = splits(1).cache()
val numTraining = training.count()
Expand All @@ -206,47 +236,62 @@ object DecisionTreeRunner {
minInfoGain = params.minInfoGain)
if (params.numTrees == 1) {
val model = DecisionTree.train(training, strategy)
println(model)
if (model.numNodes < 20) {
println(model.toDebugString) // Print full model.
} else {
println(model) // Print model summary.
}
if (params.algo == Classification) {
val accuracy =
val trainAccuracy =
new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
.precision
println(s"Train accuracy = $trainAccuracy")
val testAccuracy =
new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
println(s"Test accuracy = $accuracy")
println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
val mse = meanSquaredError(model, test)
println(s"Test mean squared error = $mse")
val trainMSE = meanSquaredError(model, training)
println(s"Train mean squared error = $trainMSE")
val testMSE = meanSquaredError(model, test)
println(s"Test mean squared error = $testMSE")
}
} else {
val randomSeed = Utils.random.nextInt()
if (params.algo == Classification) {
val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
println(model)
val accuracy =
if (model.totalNumNodes < 30) {
println(model.toDebugString) // Print full model.
} else {
println(model) // Print model summary.
}
val trainAccuracy =
new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
.precision
println(s"Train accuracy = $trainAccuracy")
val testAccuracy =
new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
println(s"Test accuracy = $accuracy")
println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
println(model)
val mse = meanSquaredError(model, test)
println(s"Test mean squared error = $mse")
if (model.totalNumNodes < 30) {
println(model.toDebugString) // Print full model.
} else {
println(model) // Print model summary.
}
val trainMSE = meanSquaredError(model, training)
println(s"Train mean squared error = $trainMSE")
val testMSE = meanSquaredError(model, test)
println(s"Test mean squared error = $testMSE")
}
}

sc.stop()
}

/**
 * Computes classification accuracy: the fraction of examples whose predicted
 * label equals the true label. Returns NaN when `data` is empty (0/0 division).
 */
private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
  val numCorrect = data.filter(lp => model.predict(lp.features) == lp.label).count()
  numCorrect.toDouble / data.count()
}

/**
* Calculates the mean squared error for regression.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,23 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
}

/**
* Print full model.
* Print a summary of the model.
*/
override def toString: String = algo match {
case Classification =>
s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2)
s"DecisionTreeModel classifier of depth $depth with $numNodes nodes"
case Regression =>
s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2)
s"DecisionTreeModel regressor of depth $depth with $numNodes nodes"
case _ => throw new IllegalArgumentException(
s"DecisionTreeModel given unknown algo parameter: $algo.")
}

/**
 * Returns the full model as a string: the one-line summary from `toString`,
 * followed by the printed tree structure (indented by 2 spaces).
 */
def toDebugString: String = {
  toString + "\n" + topNode.subtreeToString(2)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,27 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
def numTrees: Int = trees.size

/**
* Print full model.
* Get total number of nodes, summed over all trees in the forest.
*/
override def toString: String = {
val header = algo match {
case Classification =>
s"RandomForestModel classifier with $numTrees trees\n"
case Regression =>
s"RandomForestModel regressor with $numTrees trees\n"
case _ => throw new IllegalArgumentException(
s"RandomForestModel given unknown algo parameter: $algo.")
}
def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum

/**
 * Returns a one-line summary of the model: forest type and number of trees.
 */
override def toString: String = {
  // Resolve the model kind first; unknown algos are a programmer error.
  val kind = algo match {
    case Classification => "classifier"
    case Regression => "regressor"
    case _ => throw new IllegalArgumentException(
      s"RandomForestModel given unknown algo parameter: $algo.")
  }
  s"RandomForestModel $kind with $numTrees trees"
}

/**
* Print the full model to a string.
*/
def toDebugString: String = {
val header = toString + "\n"
header + trees.zipWithIndex.map { case (tree, treeIndex) =>
s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
}.fold("")(_ + _)
Expand Down
10 changes: 8 additions & 2 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,13 @@ def depth(self):
return self._java_model.depth()

def __repr__(self):
    """Return a short summary of the model (delegates to the Java model's toString())."""
    return self._java_model.toString()

def toDebugString(self):
    """Return the full model as a string (delegates to the Java model's toDebugString())."""
    return self._java_model.toDebugString()


class DecisionTree(object):

Expand Down Expand Up @@ -135,7 +140,6 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
>>> from numpy import array
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.linalg import SparseVector
>>>
>>> data = [
... LabeledPoint(0.0, [0.0]),
Expand All @@ -145,7 +149,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
... ]
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
>>> print model, # it already has newline
DecisionTreeModel classifier
DecisionTreeModel classifier of depth 1 with 3 nodes
>>> print model.toDebugString(), # it already has newline
DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.5)
Predict: 0.0
Else (feature 0 > 0.5)
Expand Down

0 comments on commit 7bf6cc9

Please sign in to comment.