Commit 115eeb3

jkbradley authored and mengxr committed
[mllib] DecisionTree: treeAggregate + Python example bug fix

Small DecisionTree updates:
* Changed main DecisionTree aggregate to treeAggregate.
* Fixed bug in python example decision_tree_runner.py with missing argument (since categoricalFeaturesInfo is no longer an optional argument for trainClassifier).
* Fixed same bug in python doc tests, and added tree.py to doc tests.

CC: mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #2015 from jkbradley/dt-opt2 and squashes the following commits:

b5114fa [Joseph K. Bradley] Fixed python tree.py doc test (extra newline)
8e4665d [Joseph K. Bradley] Added tree.py to python doc tests. Fixed bug from missing categoricalFeaturesInfo argument.
b7b2922 [Joseph K. Bradley] Fixed bug in python example decision_tree_runner.py with missing argument. Changed main DecisionTree aggregate to treeAggregate.
85bbc1f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
66d076f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
3726d20 [Joseph K. Bradley] Small code improvements based on code review.
ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using << instead of math.pow.
db0d773 [Joseph K. Bradley] scala style fix
6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code
931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level. Needed to update treePointToNodeIndex with groupShift.
f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint
6b5651e [Joseph K. Bradley] Updates based on code review. 1 major change: persisting to memory + disk, not just memory.
2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala.
356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serialiable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification
b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes
b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt
0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals

1 parent 6201b27 commit 115eeb3

4 files changed, +14 -8 lines changed

examples/src/main/python/mllib/decision_tree_runner.py

Lines changed: 3 additions & 1 deletion
@@ -124,7 +124,9 @@ def usage():
     (reindexedData, origToNewLabels) = reindexClassLabels(points)
 
     # Train a classifier.
-    model = DecisionTree.trainClassifier(reindexedData, numClasses=2)
+    categoricalFeaturesInfo={} # no categorical features
+    model = DecisionTree.trainClassifier(reindexedData, numClasses=2,
+                                         categoricalFeaturesInfo=categoricalFeaturesInfo)
     # Print learned tree and stats.
     print "Trained DecisionTree for classification:"
     print " Model numNodes: %d\n" % model.numNodes()

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.Logging
+import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
@@ -826,7 +827,7 @@ object DecisionTree extends Serializable with Logging {
     // Calculate bin aggregates.
     timer.start("aggregation")
     val binAggregates = {
-      input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
+      input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
     }
     timer.stop("aggregation")
     logDebug("binAggregates.length = " + binAggregates.length)
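
The change from aggregate to treeAggregate keeps the same zero value, binSeqOp and binCombOp; what differs is how the per-partition results are merged. As a rough illustration only, the sketch below simulates that idea in plain Python on in-memory lists: each partition is folded locally, then the partial results are merged in small groups, level by level, instead of all in one step. The names `tree_aggregate`, `fan_in`, `bin_seq_op` and `bin_comb_op` are hypothetical stand-ins and not Spark's implementation.

```python
from functools import reduce


def tree_aggregate(partitions, zero, seq_op, comb_op, fan_in=2):
    """Simulate a multi-level aggregation over in-memory partitions.

    `fan_in` (hypothetical parameter) is how many partial results are
    merged together at each level of the tree.
    """
    # Fold every partition locally, starting from a fresh copy of `zero`.
    partials = [reduce(seq_op, part, list(zero)) for part in partitions]
    if not partials:
        return list(zero)

    # Merge partial results in groups of `fan_in` until one remains.
    while len(partials) > 1:
        partials = [
            reduce(comb_op, partials[i:i + fan_in])
            for i in range(0, len(partials), fan_in)
        ]
    return partials[0]


def bin_seq_op(acc, bin_index):
    # Add one observation to the bin it falls into (mutates the accumulator).
    acc[bin_index] += 1.0
    return acc


def bin_comb_op(left, right):
    # Element-wise sum of two bin-count arrays.
    return [x + y for x, y in zip(left, right)]


# Four "partitions" of bin indices and a zero array with three bins.
partitions = [[0, 1, 1], [2, 2], [0], [1, 2, 2]]
bin_counts = tree_aggregate(partitions, [0.0, 0.0, 0.0], bin_seq_op, bin_comb_op)
print(bin_counts)
```

Because the combine step is associative, the result matches what a single flat merge would give; the tree shape only changes how much merging any one step has to do.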

python/pyspark/mllib/tree.py

Lines changed: 8 additions & 6 deletions
@@ -88,7 +88,8 @@ class DecisionTree(object):
     It will probably be modified for Spark v1.2.
 
     Example usage:
-    >>> from numpy import array, ndarray
+    >>> from numpy import array
+    >>> import sys
     >>> from pyspark.mllib.regression import LabeledPoint
     >>> from pyspark.mllib.tree import DecisionTree
     >>> from pyspark.mllib.linalg import SparseVector
@@ -99,15 +100,15 @@ class DecisionTree(object):
     ... LabeledPoint(1.0, [2.0]),
     ... LabeledPoint(1.0, [3.0])
     ... ]
-    >>>
-    >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2)
-    >>> print(model)
+    >>> categoricalFeaturesInfo = {} # no categorical features
+    >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
+    ... categoricalFeaturesInfo=categoricalFeaturesInfo)
+    >>> sys.stdout.write(model)
     DecisionTreeModel classifier
     If (feature 0 <= 0.5)
     Predict: 0.0
     Else (feature 0 > 0.5)
     Predict: 1.0
-
     >>> model.predict(array([1.0])) > 0
     True
     >>> model.predict(array([0.0])) == 0
@@ -119,7 +120,8 @@ class DecisionTree(object):
     ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
     ... ]
     >>>
-    >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
+    >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
+    ... categoricalFeaturesInfo=categoricalFeaturesInfo)
     >>> model.predict(array([0.0, 1.0])) == 1
     True
     >>> model.predict(array([0.0, 0.0])) == 0
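
The "extra newline" doc test fix can be reproduced outside Spark. The sketch below assumes, as the commit message suggests, that the model's string form already ends with a newline; `MODEL_TEXT` and `run_example` are hypothetical names used only for this illustration. Printing such a string emits one more newline, and a blank line in a docstring ends the expected output, so the example only matches if the blank line is written as `<BLANKLINE>` or the text is written without the extra newline.

```python
import doctest
import sys

# Stand-in for a model description that already ends with a newline.
MODEL_TEXT = "If (feature 0 <= 0.5)\n Predict: 0.0\n"


def run_example(source, want):
    """Run a single doctest example and return the number of failures."""
    example = doctest.Example(source, want)
    globs = {"sys": sys, "MODEL_TEXT": MODEL_TEXT}
    test = doctest.DocTest([example], globs, "newline_sketch", None, 0, None)
    runner = doctest.DocTestRunner(verbose=False)
    # Discard the failure report; only the counts matter here.
    results = runner.run(test, out=lambda text: None)
    return results.failed


# print() appends a newline, so the output has a trailing blank line
# that the expected text lacks.
print_failures = run_example("print(MODEL_TEXT)", MODEL_TEXT)

# Spelling the blank line out as <BLANKLINE> makes the same example match.
blankline_failures = run_example("print(MODEL_TEXT)", MODEL_TEXT + "<BLANKLINE>\n")

# Writing the text directly adds no newline. In Python 3, write() returns
# a character count, so the result is assigned to keep it out of the output.
write_failures = run_example("_ = sys.stdout.write(MODEL_TEXT)", MODEL_TEXT)

print(print_failures, blankline_failures, write_failures)
```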

python/run-tests

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ run_test "pyspark/mllib/random.py"
 run_test "pyspark/mllib/recommendation.py"
 run_test "pyspark/mllib/regression.py"
 run_test "pyspark/mllib/tests.py"
+run_test "pyspark/mllib/tree.py"
 run_test "pyspark/mllib/util.py"
 
 if [[ $FAILED == 0 ]]; then
