Skip to content

Commit 8464a6e

Browse files
committed
Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serializable.
1 parent a87e08f commit 8464a6e

File tree

4 files changed

+101
-116
lines changed

4 files changed

+101
-116
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 25 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20-
import java.util.Calendar
2120

2221
import scala.collection.JavaConverters._
2322

@@ -29,45 +28,12 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
2928
import org.apache.spark.mllib.tree.configuration.Algo._
3029
import org.apache.spark.mllib.tree.configuration.FeatureType._
3130
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
32-
import org.apache.spark.mllib.tree.impl.TreePoint
31+
import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
3332
import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
3433
import org.apache.spark.mllib.tree.model._
3534
import org.apache.spark.rdd.RDD
3635
import org.apache.spark.util.random.XORShiftRandom
3736

38-
class TimeTracker {
39-
40-
var tmpTime: Long = Calendar.getInstance().getTimeInMillis
41-
42-
def reset(): Unit = {
43-
tmpTime = Calendar.getInstance().getTimeInMillis
44-
}
45-
46-
def elapsed(): Long = {
47-
Calendar.getInstance().getTimeInMillis - tmpTime
48-
}
49-
50-
var initTime: Long = 0 // Data retag and cache
51-
var findSplitsBinsTime: Long = 0
52-
var extractNodeInfoTime: Long = 0
53-
var extractInfoForLowerLevelsTime: Long = 0
54-
var findBestSplitsTime: Long = 0
55-
var findBinsForLevelTime: Long = 0
56-
var binAggregatesTime: Long = 0
57-
var chooseSplitsTime: Long = 0
58-
59-
override def toString: String = {
60-
s"DecisionTree timing\n" +
61-
s"initTime: $initTime\n" +
62-
s"findSplitsBinsTime: $findSplitsBinsTime\n" +
63-
s"extractNodeInfoTime: $extractNodeInfoTime\n" +
64-
s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" +
65-
s"findBestSplitsTime: $findBestSplitsTime\n" +
66-
s"findBinsForLevelTime: $findBinsForLevelTime\n" +
67-
s"binAggregatesTime: $binAggregatesTime\n" +
68-
s"chooseSplitsTime: $chooseSplitsTime\n"
69-
}
70-
}
7137

7238
/**
7339
* :: Experimental ::
@@ -90,26 +56,26 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
9056
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
9157

9258
val timer = new TimeTracker()
93-
timer.reset()
9459

60+
timer.start("total")
61+
62+
timer.start("init")
9563
// Cache input RDD for speedup during multiple passes.
9664
val retaggedInput = input.retag(classOf[LabeledPoint])
9765
logDebug("algo = " + strategy.algo)
98-
99-
timer.initTime += timer.elapsed()
100-
timer.reset()
66+
timer.stop("init")
10167

10268
// Find the splits and the corresponding bins (interval between the splits) using a sample
10369
// of the input data.
70+
timer.start("findSplitsBins")
10471
val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
10572
val numBins = bins(0).length
73+
timer.stop("findSplitsBins")
10674
logDebug("numBins = " + numBins)
10775

108-
timer.findSplitsBinsTime += timer.elapsed()
109-
110-
timer.reset()
76+
timer.start("init")
11177
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
112-
timer.initTime += timer.elapsed()
78+
timer.stop("init")
11379

11480
// depth of the decision tree
11581
val maxDepth = strategy.maxDepth
@@ -166,21 +132,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
166132

167133

168134
// Find best split for all nodes at a level.
169-
timer.reset()
135+
timer.start("findBestSplits")
170136
val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
171137
strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer)
172-
timer.findBestSplitsTime += timer.elapsed()
138+
timer.stop("findBestSplits")
173139

174140
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
175-
timer.reset()
141+
timer.start("extractNodeInfo")
176142
// Extract info for nodes at the current level.
177143
extractNodeInfo(nodeSplitStats, level, index, nodes)
178-
timer.extractNodeInfoTime += timer.elapsed()
179-
timer.reset()
144+
timer.stop("extractNodeInfo")
145+
timer.start("extractInfoForLowerLevels")
180146
// Extract info for nodes at the next lower level.
181147
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
182148
filters)
183-
timer.extractInfoForLowerLevelsTime += timer.elapsed()
149+
timer.stop("extractInfoForLowerLevels")
184150
logDebug("final best split = " + nodeSplitStats._1)
185151
}
186152
require(math.pow(2, level) == splitsStatsForLevel.length)
@@ -194,8 +160,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
194160
}
195161
}
196162

197-
println(timer)
198-
199163
logDebug("#####################################")
200164
logDebug("Extracting tree model")
201165
logDebug("#####################################")
@@ -205,6 +169,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
205169
// Build the full tree using the node info calculated in the level-wise best split calculations.
206170
topNode.build(nodes)
207171

172+
timer.stop("total")
173+
174+
//println(timer) // Print internal timing info.
175+
208176
new DecisionTreeModel(topNode, strategy.algo)
209177
}
210178

@@ -252,7 +220,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
252220
// noting the parents filters for the child nodes
253221
val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
254222
filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
255-
//println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}")
256223
for (filter <- filters(nodeIndex)) {
257224
logDebug("Filter = " + filter)
258225
}
@@ -491,7 +458,6 @@ object DecisionTree extends Serializable with Logging {
491458
maxLevelForSingleGroup: Int,
492459
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
493460
// split into groups to avoid memory overflow during aggregation
494-
//println(s"findBestSplits: level = $level")
495461
if (level > maxLevelForSingleGroup) {
496462
// When information for all nodes at a given level cannot be stored in memory,
497463
// the nodes are divided into multiple groups at each level with the number of groups
@@ -681,7 +647,6 @@ object DecisionTree extends Serializable with Logging {
681647
val parentFilters = findParentFilters(nodeIndex)
682648
// Find out whether the sample qualifies for the particular node.
683649
val sampleValid = isSampleValid(parentFilters, treePoint)
684-
//println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}")
685650
val shift = 1 + numFeatures * nodeIndex
686651
if (!sampleValid) {
687652
// Mark one bin as -1 is sufficient.
@@ -699,12 +664,12 @@ object DecisionTree extends Serializable with Logging {
699664
arr
700665
}
701666

702-
timer.reset()
667+
timer.start("findBinsForLevel")
703668

704669
// Find feature bins for all nodes at a level.
705670
val binMappedRDD = input.map(x => findBinsForLevel(x))
706671

707-
timer.findBinsForLevelTime += timer.elapsed()
672+
timer.stop("findBinsForLevel")
708673

709674
/**
710675
* Increment aggregate in location for (node, feature, bin, label).
@@ -752,7 +717,6 @@ object DecisionTree extends Serializable with Logging {
752717
label: Double,
753718
agg: Array[Double],
754719
rightChildShift: Int): Unit = {
755-
//println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
756720
// Find the bin index for this feature.
757721
val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
758722
val featureValue = arr(arrIndex).toInt
@@ -830,10 +794,6 @@ object DecisionTree extends Serializable with Logging {
830794
// Check whether the instance was valid for this nodeIndex.
831795
val validSignalIndex = 1 + numFeatures * nodeIndex
832796
val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
833-
if (level == 1) {
834-
val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift
835-
//println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}")
836-
}
837797
if (isSampleValidForNode) {
838798
// actual class label
839799
val label = arr(0)
@@ -954,39 +914,15 @@ object DecisionTree extends Serializable with Logging {
954914
combinedAggregate
955915
}
956916

957-
timer.reset()
958917

959918
// Calculate bin aggregates.
919+
timer.start("binAggregates")
960920
val binAggregates = {
961921
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
962922
}
923+
timer.stop("binAggregates")
963924
logDebug("binAggregates.length = " + binAggregates.length)
964925

965-
timer.binAggregatesTime += timer.elapsed()
966-
//2 * numClasses * numBins * numFeatures * numNodes for unordered features.
967-
// (left/right, node, feature, bin, label)
968-
/*
969-
println(s"binAggregates:")
970-
for (i <- Range(0,2)) {
971-
for (n <- Range(0,numNodes)) {
972-
for (f <- Range(0,numFeatures)) {
973-
for (b <- Range(0,4)) {
974-
for (c <- Range(0,numClasses)) {
975-
val idx = i * numClasses * numBins * numFeatures * numNodes +
976-
n * numClasses * numBins * numFeatures +
977-
f * numBins * numFeatures +
978-
b * numFeatures +
979-
c
980-
if (binAggregates(idx) != 0) {
981-
println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}")
982-
}
983-
}
984-
}
985-
}
986-
}
987-
}
988-
*/
989-
990926
/**
991927
* Calculates the information gain for all splits based upon left/right split aggregates.
992928
* @param leftNodeAgg left node aggregates
@@ -1027,7 +963,6 @@ object DecisionTree extends Serializable with Logging {
1027963
val totalCount = leftTotalCount + rightTotalCount
1028964
if (totalCount == 0) {
1029965
// Return arbitrary prediction.
1030-
//println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0")
1031966
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
1032967
}
1033968

@@ -1054,9 +989,6 @@ object DecisionTree extends Serializable with Logging {
1054989
}
1055990

1056991
val predict = indexOfLargestArrayElement(leftRightCounts)
1057-
if (predict == 0 && featureIndex == 0 && splitIndex == 0) {
1058-
//println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}")
1059-
}
1060992
val prob = leftRightCounts(predict) / totalCount
1061993

1062994
val leftImpurity = if (leftTotalCount == 0) {
@@ -1209,7 +1141,6 @@ object DecisionTree extends Serializable with Logging {
12091141
}
12101142
splitIndex += 1
12111143
}
1212-
//println(s"found Agg: $TMPDEBUG")
12131144
}
12141145

12151146
def findAggForRegression(
@@ -1369,7 +1300,6 @@ object DecisionTree extends Serializable with Logging {
13691300
bestGainStats = gainStats
13701301
bestFeatureIndex = featureIndex
13711302
bestSplitIndex = splitIndex
1372-
//println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats")
13731303
}
13741304
splitIndex += 1
13751305
}
@@ -1414,7 +1344,7 @@ object DecisionTree extends Serializable with Logging {
14141344
}
14151345
}
14161346

1417-
timer.reset()
1347+
timer.start("chooseSplits")
14181348

14191349
// Calculate best splits for all nodes at a given level
14201350
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
@@ -1427,10 +1357,9 @@ object DecisionTree extends Serializable with Logging {
14271357
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
14281358
logDebug("parent node impurity = " + parentNodeImpurity)
14291359
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
1430-
//println(s"bestSplits(node:$node): ${bestSplits(node)}")
14311360
node += 1
14321361
}
1433-
timer.chooseSplitsTime += timer.elapsed()
1362+
timer.stop("chooseSplits")
14341363

14351364
bestSplits
14361365
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.tree.impl
19+
20+
import scala.collection.mutable.{HashMap => MutableHashMap}
21+
22+
import org.apache.spark.annotation.Experimental
23+
24+
/**
 * Time tracker implementation which holds labeled timers.
 *
 * Each label identifies an independent timer. A timer is started with `start` and stopped with
 * `stop`; the duration of every completed start/stop interval is added to a running total for
 * that label. All times are taken from `System.nanoTime()` and are therefore in nanoseconds.
 *
 * Timer state is kept in plain mutable hash maps with no synchronization.
 */
@Experimental
private[tree]
class TimeTracker extends Serializable {

  /** Start timestamps (`System.nanoTime()`) of the timers that are currently running. */
  private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()

  /** Accumulated elapsed nanoseconds per label, summed over all completed intervals. */
  private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()

  /**
   * Starts a new timer, or re-starts a stopped timer.
   *
   * @param timerLabel label identifying the timer
   * @throws RuntimeException if the timer with this label is already running
   */
  def start(timerLabel: String): Unit = {
    // Read the clock first so the bookkeeping below is not excluded from the measured interval.
    val startTime = System.nanoTime()
    if (starts.contains(timerLabel)) {
      throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
        s" timerLabel = $timerLabel before that timer was stopped.")
    }
    starts(timerLabel) = startTime
  }

  /**
   * Stops a timer and returns the elapsed time in nanoseconds.
   *
   * The elapsed time of this interval is also added to the running total for the label.
   *
   * @param timerLabel label identifying the timer
   * @return nanoseconds elapsed since the matching call to `start`
   * @throws RuntimeException if the timer with this label is not currently running
   */
  def stop(timerLabel: String): Long = {
    // Read the clock first so the map operations below are not counted in the interval.
    val stopTime = System.nanoTime()
    // `remove` both looks up and clears the running timer in a single map operation.
    val startTime = starts.remove(timerLabel).getOrElse {
      throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
        s" timerLabel = $timerLabel, but that timer was not started.")
    }
    val elapsed = stopTime - startTime
    totals(timerLabel) = totals.getOrElse(timerLabel, 0L) + elapsed
    elapsed
  }

  /**
   * Returns a multi-line report of all timing results.
   *
   * The report is the header "Timing" followed by one "label: total" line for every timer that
   * has been stopped at least once. Totals are in nanoseconds. Lines are sorted by label so the
   * report is stable across runs instead of following hash-map iteration order.
   */
  override def toString: String = {
    "Timing\n" +
      totals.toSeq.sortBy(_._1).map { case (label, elapsed) =>
        s" $label: $elapsed"
      }.mkString("\n")
  }
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD
2727
* of size (numFeatures, numBins).
2828
* TODO: ADD DOC
2929
*/
30-
private[tree] class TreePoint(val label: Double, val features: Array[Int]) {
30+
private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable {
3131
}
3232

3333
private[tree] object TreePoint {

0 commit comments

Comments
 (0)