@@ -37,7 +37,6 @@ class DecisionTree(val strategy : Strategy) {
    val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)

    // TODO: Level-wise training of tree and obtain Decision Tree model
-
    val maxDepth = strategy.maxDepth

    val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
@@ -55,8 +54,20 @@ class DecisionTree(val strategy : Strategy) {

  }

-object DecisionTree extends Logging {
+object DecisionTree extends Serializable {
+
+  /*
+   Returns an Array[Split] of optimal splits for all nodes at a given level
+
+   @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
+   @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
+   @param level Level of the tree
+   @param filters Filters for all nodes at a given level
+   @param splits possible splits for all features
+   @param bins possible bins for all features
+
+   @return Array[Split] of best splits for all nodes at a given level.
+   */
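+  // Nodes at a given level are indexed 0 until 2^level (see numNodes below), so a tree of
+  // depth maxDepth is trained level by level, one findBestSplits pass over the data per level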
  def findBestSplits(
      input : RDD[LabeledPoint],
      strategy : Strategy,
@@ -65,6 +76,16 @@ object DecisionTree extends Logging {
      splits : Array[Array[Split]],
      bins : Array[Array[Bin]]) : Array[Split] = {

+    // TODO: Move these calculations outside
+    val numNodes = scala.math.pow(2, level).toInt
+    println("numNodes = " + numNodes)
+    // Find the number of features by looking at the first sample
+    val numFeatures = input.take(1)(0).features.length
+    println("numFeatures = " + numFeatures)
+    val numSplits = strategy.numSplits
+    println("numSplits = " + numSplits)
+
+    /* Find the filters used before reaching the current node */
    def findParentFilters(nodeIndex : Int): List[Filter] = {
      if (level == 0) {
        List[Filter]()
@@ -75,6 +96,10 @@ object DecisionTree extends Logging {
      }
    }

+    /* Find whether the sample is a valid input for the current node.
+
+       In other words, does it pass through all the filters for the current node?
+    */
    def isSampleValid(parentFilters : List[Filter], labeledPoint : LabeledPoint): Boolean = {

      for (filter <- parentFilters) {
@@ -91,79 +116,130 @@ object DecisionTree extends Logging {
      true
    }

+    /* Finds the right bin for the given feature */
    def findBin(featureIndex : Int, labeledPoint : LabeledPoint) : Int = {
-
      // TODO: Do binary search
      for (binIndex <- 0 until strategy.numSplits) {
        val bin = bins(featureIndex)(binIndex)
-        // TODO: Remove this requirement post basic functional testing
-        require(bin.lowSplit.feature == featureIndex)
-        require(bin.highSplit.feature == featureIndex)
+        // TODO: Remove this requirement post basic functional testing
        val lowThreshold = bin.lowSplit.threshold
        val highThreshold = bin.highSplit.threshold
        val features = labeledPoint.features
-        if ((lowThreshold < features(featureIndex)) & (highThreshold < features(featureIndex))) {
+        if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
          return binIndex
        }
      }
      throw new UnknownError("no bin was found.")
    }
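+    // Worked example for findBin, with hypothetical thresholds: if a feature's bins are
+    // [-inf, 5.0), [5.0, 9.0) and [9.0, +inf), a value of 6.2 satisfies 5.0 <= 6.2 and
+    // 9.0 > 6.2, so findBin returns binIndex 1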
-    def findBinsForLevel : Array[Double] = {

-      val numNodes = scala.math.pow(2, level).toInt
-      // Find the number of features by looking at the first sample
-      val numFeatures = input.take(1)(0).features.length
+    /* Finds bins for all nodes (and all features) at a given level
+       k features, l nodes
+       Storage: label, b11, b12, ..., b1k, b21, b22, ..., b2k, ..., bl1, bl2, ..., blk
+       Denotes an invalid sample for the tree by marking the bin for feature 1 as -1
+    */
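+    // A worked example of the layout above, with hypothetical numbers: for l = 2 nodes and
+    // k = 3 features, a sample with label 1.0 that falls in bins (2, 0, 5) at node 0 and is
+    // invalid at node 1 is stored as Array(1.0, 2, 0, 5, -1, -1, -1)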
+    def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = {
+

-      // TODO: Bit pack more by removing redundant label storage
      // calculating bin index and label per feature per node
-      val arr = new Array[Double](2 * numFeatures * numNodes)
+      val arr = new Array[Double](1 + (numFeatures * numNodes))
+      arr(0) = labeledPoint.label
      for (nodeIndex <- 0 until numNodes) {
        val parentFilters = findParentFilters(nodeIndex)
        // Find out whether the sample qualifies for the particular node
        val sampleValid = isSampleValid(parentFilters, labeledPoint)
-        val shift = 2 * numFeatures * nodeIndex
-        if (sampleValid) {
+        val shift = 1 + numFeatures * nodeIndex
+        if (!sampleValid) {
          // Mark an invalid sample with bin index -1
-          for (featureIndex <- shift until (shift + numFeatures) by 2) {
-            arr(featureIndex + 1) = -1
-            arr(featureIndex + 2) = labeledPoint.label
+          for (featureIndex <- 0 until numFeatures) {
+            arr(shift + featureIndex) = -1
+            // TODO: Break since marking one bin is sufficient
          }
        } else {
          for (featureIndex <- 0 until numFeatures) {
-            arr(shift + (featureIndex * 2) + 1) = findBin(featureIndex, labeledPoint)
-            arr(shift + (featureIndex * 2) + 2) = labeledPoint.label
+            arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
          }
        }
      }
      arr
    }

-    val binMappedRDD = input.map(labeledPoint => findBinsForLevel)
+    /*
+     Performs a sequential aggregation over a partition
+
+     @param agg Array[Double] storing aggregate calculation of size numSplits * numFeatures * numNodes
+     for classification and 3 * numSplits * numFeatures * numNodes for regression
+     @param arr Array[Double] of size 1 + (numFeatures * numNodes)
+     @return Array[Double] storing aggregate calculation of size numSplits * numFeatures * numNodes
+     for classification and 3 * numSplits * numFeatures * numNodes for regression
+    */
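+    // Flattened (node, feature, bin) layout of agg, with hypothetical numbers: for
+    // numNodes = 2, numFeatures = 3 and numSplits = 4, the count for node 1, feature 2,
+    // bin 3 lands at aggIndex = (4 * 3) * 1 + 2 * 4 + 3 = 23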
+    def binSeqOp(agg : Array[Double], arr : Array[Double]) : Array[Double] = {
+      for (node <- 0 until numNodes) {
+        val validSignalIndex = 1 + numFeatures * node
+        val isSampleValidForNode = arr(validSignalIndex) != -1
+        if (isSampleValidForNode) {
+          for (feature <- 0 until numFeatures) {
+            val arrShift = 1 + numFeatures * node
+            val aggShift = numSplits * numFeatures * node
+            val arrIndex = arrShift + feature
+            val aggIndex = aggShift + feature * numSplits + arr(arrIndex).toInt
+            agg(aggIndex) = agg(aggIndex) + 1
+          }
+        }
+      }
+      agg
+    }
+
+    // Combines bin aggregates from two partitions element-wise
+    def binCombOp(par1 : Array[Double], par2 : Array[Double]) : Array[Double] = {
+      par1.zip(par2).map { case (x, y) => x + y }
+    }
+
+    println("input = " + input.count)
+    val binMappedRDD = input.map(x => findBinsForLevel(x))
+    println("binMappedRDD.count = " + binMappedRDD.count)
    // calculate bin aggregates
+
+    val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits * numFeatures * numNodes)(0))(binSeqOp, binCombOp)
+
    // find best split
+    println("binAggregates.length = " + binAggregates.length)

-    Array[Split]()
+    val bestSplits = new Array[Split](numNodes)
+    for (node <- 0 until numNodes) {
+      // Per-node slice of the flattened bin aggregates; best-split selection is still TODO
+      val binsForNode = binAggregates.slice(numSplits * numFeatures * node, numSplits * numFeatures * (node + 1))
+    }
+
+    bestSplits
  }
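+  // Hypothetical usage sketch (names assumed, not part of this commit): train level by level,
+  //   for (level <- 0 until strategy.maxDepth) {
+  //     val best = findBestSplits(input, strategy, level, filters, splits, bins)
+  //     // derive Filters from best and extend filters before descending a level
+  //   }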

+    /*
+     Returns splits and bins for decision tree calculation.
+
+     @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
+     @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
+     @return a tuple of (splits, bins) where splits is an Array[Array[Split]] of size (numFeatures, numSplits - 1)
+     and bins is an Array[Array[Bin]] of size (numFeatures, numSplits)
+    */
  def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {

    val numSplits = strategy.numSplits
-    logDebug("numSplits = " + numSplits)
+    println("numSplits = " + numSplits)

    // Calculate the number of samples for approximate quantile calculation
    // TODO: Justify this calculation
    val requiredSamples = numSplits * numSplits
    val count = input.count()
    val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
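    // e.g. with hypothetical numbers: numSplits = 100 gives requiredSamples = 10000, so an
    // input of 1000000 points would be sampled at fraction 0.01 for quantile estimation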
-    logDebug("fraction of data used for calculating quantiles = " + fraction)
+    println("fraction of data used for calculating quantiles = " + fraction)

    // sampled input for RDD calculation
    val sampledInput = input.sample(false, fraction, 42).collect()
    val numSamples = sampledInput.length

+    // TODO: Remove this requirement
    require(numSamples > numSplits, "length of input samples should be greater than numSplits")

    // Find the number of features by looking at the first sample