package org.apache.spark.mllib.tree

-import java.util.Calendar

import scala.collection.JavaConverters._

@@ -29,45 +28,12 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impl.TreePoint
+import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom

-class TimeTracker {
-
-  var tmpTime: Long = Calendar.getInstance().getTimeInMillis
-
-  def reset(): Unit = {
-    tmpTime = Calendar.getInstance().getTimeInMillis
-  }
-
-  def elapsed(): Long = {
-    Calendar.getInstance().getTimeInMillis - tmpTime
-  }
-
-  var initTime: Long = 0 // Data retag and cache
-  var findSplitsBinsTime: Long = 0
-  var extractNodeInfoTime: Long = 0
-  var extractInfoForLowerLevelsTime: Long = 0
-  var findBestSplitsTime: Long = 0
-  var findBinsForLevelTime: Long = 0
-  var binAggregatesTime: Long = 0
-  var chooseSplitsTime: Long = 0
-
-  override def toString: String = {
-    s"DecisionTree timing\n" +
-      s"initTime: $initTime\n" +
-      s"findSplitsBinsTime: $findSplitsBinsTime\n" +
-      s"extractNodeInfoTime: $extractNodeInfoTime\n" +
-      s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" +
-      s"findBestSplitsTime: $findBestSplitsTime\n" +
-      s"findBinsForLevelTime: $findBinsForLevelTime\n" +
-      s"binAggregatesTime: $binAggregatesTime\n" +
-      s"chooseSplitsTime: $chooseSplitsTime\n"
-  }
-}
-
/**
 * :: Experimental ::
@@ -90,26 +56,26 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {

    val timer = new TimeTracker()
-    timer.reset()

+    timer.start("total")
+
+    timer.start("init")
    // Cache input RDD for speedup during multiple passes.
    val retaggedInput = input.retag(classOf[LabeledPoint])
    logDebug("algo = " + strategy.algo)
-
-    timer.initTime += timer.elapsed()
-    timer.reset()
+    timer.stop("init")

    // Find the splits and the corresponding bins (interval between the splits) using a sample
    // of the input data.
+    timer.start("findSplitsBins")
    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
    val numBins = bins(0).length
+    timer.stop("findSplitsBins")
    logDebug("numBins = " + numBins)

-    timer.findSplitsBinsTime += timer.elapsed()
-
-    timer.reset()
+    timer.start("init")
    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
-    timer.initTime += timer.elapsed()
+    timer.stop("init")

    // depth of the decision tree
    val maxDepth = strategy.maxDepth
@@ -166,21 +132,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo


      // Find best split for all nodes at a level.
-      timer.reset()
+      timer.start("findBestSplits")
      val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
        strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer)
-      timer.findBestSplitsTime += timer.elapsed()
+      timer.stop("findBestSplits")

      for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
-        timer.reset()
+        timer.start("extractNodeInfo")
        // Extract info for nodes at the current level.
        extractNodeInfo(nodeSplitStats, level, index, nodes)
-        timer.extractNodeInfoTime += timer.elapsed()
-        timer.reset()
+        timer.stop("extractNodeInfo")
+        timer.start("extractInfoForLowerLevels")
        // Extract info for nodes at the next lower level.
        extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
          filters)
-        timer.extractInfoForLowerLevelsTime += timer.elapsed()
+        timer.stop("extractInfoForLowerLevels")
        logDebug("final best split = " + nodeSplitStats._1)
      }
      require(math.pow(2, level) == splitsStatsForLevel.length)
@@ -194,8 +160,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
      }
    }

-    println(timer)
-
    logDebug("#####################################")
    logDebug("Extracting tree model")
    logDebug("#####################################")
@@ -205,6 +169,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    // Build the full tree using the node info calculated in the level-wise best split calculations.
    topNode.build(nodes)

+    timer.stop("total")
+
+    // println(timer) // Print internal timing info.
+
    new DecisionTreeModel(topNode, strategy.algo)
  }

@@ -252,7 +220,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
        // noting the parents filters for the child nodes
        val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
        filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
-        // println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}")
        for (filter <- filters(nodeIndex)) {
          logDebug("Filter = " + filter)
        }
@@ -491,7 +458,6 @@ object DecisionTree extends Serializable with Logging {
      maxLevelForSingleGroup: Int,
      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
    // split into groups to avoid memory overflow during aggregation
-    // println(s"findBestSplits: level = $level")
    if (level > maxLevelForSingleGroup) {
      // When information for all nodes at a given level cannot be stored in memory,
      // the nodes are divided into multiple groups at each level with the number of groups
@@ -681,7 +647,6 @@ object DecisionTree extends Serializable with Logging {
        val parentFilters = findParentFilters(nodeIndex)
        // Find out whether the sample qualifies for the particular node.
        val sampleValid = isSampleValid(parentFilters, treePoint)
-        // println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}")
        val shift = 1 + numFeatures * nodeIndex
        if (!sampleValid) {
          // Mark one bin as -1 is sufficient.
@@ -699,12 +664,12 @@ object DecisionTree extends Serializable with Logging {
      arr
    }

-    timer.reset()
+    timer.start("findBinsForLevel")

    // Find feature bins for all nodes at a level.
    val binMappedRDD = input.map(x => findBinsForLevel(x))

-    timer.findBinsForLevelTime += timer.elapsed()
+    timer.stop("findBinsForLevel")

    /**
     * Increment aggregate in location for (node, feature, bin, label).
@@ -752,7 +717,6 @@ object DecisionTree extends Serializable with Logging {
        label: Double,
        agg: Array[Double],
        rightChildShift: Int): Unit = {
-      // println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
      // Find the bin index for this feature.
      val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
      val featureValue = arr(arrIndex).toInt
@@ -830,10 +794,6 @@ object DecisionTree extends Serializable with Logging {
      // Check whether the instance was valid for this nodeIndex.
      val validSignalIndex = 1 + numFeatures * nodeIndex
      val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
-      if (level == 1) {
-        val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift
-        // println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}")
-      }
      if (isSampleValidForNode) {
        // actual class label
        val label = arr(0)
@@ -954,39 +914,15 @@ object DecisionTree extends Serializable with Logging {
      combinedAggregate
    }

-    timer.reset()

    // Calculate bin aggregates.
+    timer.start("binAggregates")
    val binAggregates = {
      binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
    }
+    timer.stop("binAggregates")
    logDebug("binAggregates.length = " + binAggregates.length)

-    timer.binAggregatesTime += timer.elapsed()
-    // 2 * numClasses * numBins * numFeatures * numNodes for unordered features.
-    // (left/right, node, feature, bin, label)
-    /*
-    println(s"binAggregates:")
-    for (i <- Range(0,2)) {
-      for (n <- Range(0,numNodes)) {
-        for (f <- Range(0,numFeatures)) {
-          for (b <- Range(0,4)) {
-            for (c <- Range(0,numClasses)) {
-              val idx = i * numClasses * numBins * numFeatures * numNodes +
-                n * numClasses * numBins * numFeatures +
-                f * numBins * numFeatures +
-                b * numFeatures +
-                c
-              if (binAggregates(idx) != 0) {
-                println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}")
-              }
-            }
-          }
-        }
-      }
-    }
-    */
-
    /**
     * Calculates the information gain for all splits based upon left/right split aggregates.
     * @param leftNodeAgg left node aggregates
@@ -1027,7 +963,6 @@ object DecisionTree extends Serializable with Logging {
      val totalCount = leftTotalCount + rightTotalCount
      if (totalCount == 0) {
        // Return arbitrary prediction.
-        // println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0")
        return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
      }

@@ -1054,9 +989,6 @@ object DecisionTree extends Serializable with Logging {
      }

      val predict = indexOfLargestArrayElement(leftRightCounts)
-      if (predict == 0 && featureIndex == 0 && splitIndex == 0) {
-        // println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}")
-      }
      val prob = leftRightCounts(predict) / totalCount

      val leftImpurity = if (leftTotalCount == 0) {
@@ -1209,7 +1141,6 @@ object DecisionTree extends Serializable with Logging {
        }
        splitIndex += 1
      }
-      // println(s"found Agg: $TMPDEBUG")
    }

    def findAggForRegression(
@@ -1369,7 +1300,6 @@ object DecisionTree extends Serializable with Logging {
            bestGainStats = gainStats
            bestFeatureIndex = featureIndex
            bestSplitIndex = splitIndex
-            // println(s"  feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats")
          }
          splitIndex += 1
        }
@@ -1414,7 +1344,7 @@ object DecisionTree extends Serializable with Logging {
      }
    }

-    timer.reset()
+    timer.start("chooseSplits")

    // Calculate best splits for all nodes at a given level
    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
@@ -1427,10 +1357,9 @@ object DecisionTree extends Serializable with Logging {
      val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
      logDebug("parent node impurity = " + parentNodeImpurity)
      bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
-      // println(s"bestSplits(node:$node): ${bestSplits(node)}")
      node += 1
    }
-    timer.chooseSplitsTime += timer.elapsed()
+    timer.stop("chooseSplits")

    bestSplits
  }
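Note: this diff replaces the ad-hoc reset()/elapsed() stopwatch removed above with named timers from org.apache.spark.mllib.tree.impl.TimeTracker (timer.start("init") / timer.stop("init")), but that new class itself is not part of the diff. The code below is only a minimal sketch of a tracker with the start/stop/toString surface used here, assuming System.nanoTime-based accounting with per-name accumulation; the actual impl.TimeTracker may be implemented differently.

import scala.collection.mutable

// Hypothetical sketch only -- the real org.apache.spark.mllib.tree.impl.TimeTracker
// referenced by this diff is not shown here and may differ.
class TimeTracker extends Serializable {

  // Start timestamps (ns) for currently running timers, keyed by name.
  private val starts: mutable.Map[String, Long] = mutable.Map.empty

  // Accumulated elapsed time (ns) per timer name, across repeated start/stop cycles.
  private val totals: mutable.Map[String, Long] = mutable.Map.empty

  /** Start (or restart) the named timer. */
  def start(name: String): Unit = {
    starts(name) = System.nanoTime()
  }

  /** Stop the named timer and add its elapsed time to the running total for that name. */
  def stop(name: String): Unit = {
    val end = System.nanoTime()
    val begin = starts.remove(name).getOrElse(
      throw new IllegalStateException(s"stop('$name') called without a matching start"))
    totals(name) = totals.getOrElse(name, 0L) + (end - begin)
  }

  /** One line per timer, in seconds, so println(timer) gives a readable timing summary. */
  override def toString: String = {
    totals.map { case (name, ns) => s"  $name: ${ns / 1e9}" }
      .mkString("Internal timing:\n", "\n", "")
  }
}

With a tracker along these lines, the pattern in the diff ("init" started and stopped twice, phases nested inside "total") accumulates a per-phase total, and the commented-out println(timer) near the end of train() would print one line per named phase.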