@@ -121,7 +121,7 @@ object DecisionTree extends Serializable {
121
121
122
122
/* Finds the right bin for the given feature */
123
123
def findBin (featureIndex : Int , labeledPoint : LabeledPoint ) : Int = {
124
- println(" finding bin for labeled point " + labeledPoint.features(featureIndex))
124
+ // println("finding bin for labeled point " + labeledPoint.features(featureIndex))
125
125
// TODO: Do binary search
126
126
for (binIndex <- 0 until strategy.numSplits) {
127
127
val bin = bins(featureIndex)(binIndex)
@@ -227,21 +227,27 @@ object DecisionTree extends Serializable {
227
227
228
228
val binAggregates = binMappedRDD.aggregate(Array .fill[Double ](2 * numSplits* numFeatures* numNodes)(0 ))(binSeqOp,binCombOp)
229
229
println(" binAggregates.length = " + binAggregates.length)
230
- binAggregates.foreach(x => println(x))
230
+ // binAggregates.foreach(x => println(x))
231
231
232
232
233
233
def calculateGainForSplit (leftNodeAgg : Array [Array [Double ]], featureIndex : Int , index : Int , rightNodeAgg : Array [Array [Double ]], topImpurity : Double ): Double = {
234
234
235
235
val left0Count = leftNodeAgg(featureIndex)(2 * index)
236
236
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1 )
237
237
val leftCount = left0Count + left1Count
238
- println(" left0count = " + left0Count + " , left1count = " + left1Count + " , leftCount = " + leftCount)
238
+
239
+ if (leftCount == 0 ) return 0
240
+
241
+ // println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
239
242
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
240
243
241
244
val right0Count = rightNodeAgg(featureIndex)(2 * index)
242
245
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1 )
243
246
val rightCount = right0Count + right1Count
244
- println(" right0count = " + right0Count + " , right1count = " + right1Count + " , rightCount = " + rightCount)
247
+
248
+ if (rightCount == 0 ) return 0
249
+
250
+ // println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
245
251
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
246
252
247
253
val leftWeight = leftCount.toDouble / (leftCount + rightCount)
@@ -261,21 +267,21 @@ object DecisionTree extends Serializable {
261
267
def extractLeftRightNodeAggregates (binData : Array [Double ]): (Array [Array [Double ]], Array [Array [Double ]]) = {
262
268
val leftNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
263
269
val rightNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
264
- println(" binData.length = " + binData.length)
265
- println(" binData.sum = " + binData.sum)
270
+ // println("binData.length = " + binData.length)
271
+ // println("binData.sum = " + binData.sum)
266
272
for (featureIndex <- 0 until numFeatures) {
267
- println(" featureIndex = " + featureIndex)
273
+ // println("featureIndex = " + featureIndex)
268
274
val shift = 2 * featureIndex* numSplits
269
275
leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
270
- println(" binData(shift + 0) = " + binData(shift + 0 ))
276
+ // println("binData(shift + 0) = " + binData(shift + 0))
271
277
leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
272
- println(" binData(shift + 1) = " + binData(shift + 1 ))
278
+ // println("binData(shift + 1) = " + binData(shift + 1))
273
279
rightNodeAgg(featureIndex)(2 * (numSplits - 2 )) = binData(shift + (2 * (numSplits - 1 )))
274
- println(binData(shift + (2 * (numSplits - 1 ))))
280
+ // println(binData(shift + (2 * (numSplits - 1))))
275
281
rightNodeAgg(featureIndex)(2 * (numSplits - 2 ) + 1 ) = binData(shift + (2 * (numSplits - 1 )) + 1 )
276
- println(binData(shift + (2 * (numSplits - 1 )) + 1 ))
282
+ // println(binData(shift + (2 * (numSplits - 1)) + 1))
277
283
for (splitIndex <- 1 until numSplits - 1 ) {
278
- println(" splitIndex = " + splitIndex)
284
+ // println("splitIndex = " + splitIndex)
279
285
leftNodeAgg(featureIndex)(2 * splitIndex)
280
286
= binData(shift + 2 * splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 )
281
287
leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
@@ -295,7 +301,7 @@ object DecisionTree extends Serializable {
295
301
296
302
for (featureIndex <- 0 until numFeatures) {
297
303
for (index <- 0 until numSplits - 1 ) {
298
- println(" splitIndex = " + index)
304
+ // println("splitIndex = " + index)
299
305
gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
300
306
}
301
307
}
@@ -312,8 +318,8 @@ object DecisionTree extends Serializable {
312
318
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
313
319
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
314
320
315
- println(" gains.size = " + gains.size)
316
- println(" gains(0).size = " + gains(0 ).size)
321
+ // println("gains.size = " + gains.size)
322
+ // println("gains(0).size = " + gains(0).size)
317
323
318
324
val (bestFeatureIndex,bestSplitIndex) = {
319
325
var bestFeatureIndex = 0
@@ -322,7 +328,7 @@ object DecisionTree extends Serializable {
322
328
for (featureIndex <- 0 until numFeatures) {
323
329
for (splitIndex <- 0 until numSplits - 1 ){
324
330
val gain = gains(featureIndex)(splitIndex)
325
- println(" featureIndex = " + featureIndex + " , splitIndex = " + splitIndex + " , gain = " + gain)
331
+ // println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
326
332
if (gain > maxGain) {
327
333
maxGain = gain
328
334
bestFeatureIndex = featureIndex
@@ -335,6 +341,8 @@ object DecisionTree extends Serializable {
335
341
}
336
342
337
343
splits(bestFeatureIndex)(bestSplitIndex)
344
+
345
+ // TODO: Return array of node stats with split and impurity information
338
346
}
339
347
340
348
// Calculate best splits for all nodes at a given level
@@ -388,6 +396,9 @@ object DecisionTree extends Serializable {
388
396
for (featureIndex <- 0 until numFeatures){
389
397
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
390
398
val stride : Double = numSamples.toDouble/ numSplits
399
+
400
+ println(" stride = " + stride)
401
+
391
402
for (index <- 0 until numSplits- 1 ) {
392
403
val sampleIndex = (index+ 1 )* stride.toInt
393
404
val split = new Split (featureIndex,featureSamples(sampleIndex)," continuous" )
0 commit comments