@@ -20,9 +20,10 @@ package org.apache.spark.mllib.tree
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.tree.model._
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.Split
+import org.apache.spark.mllib.tree.impurity.Gini
 
 
 class DecisionTree(val strategy: Strategy) {
@@ -46,8 +47,13 @@ class DecisionTree(val strategy: Strategy) {
       // Find best split for all nodes at a level
       val numNodes = scala.math.pow(2, level).toInt
       // TODO: Change the input parent impurities values
-      val bestSplits = DecisionTree.findBestSplits(input, Array(0.0), strategy, level, filters, splits, bins)
+      val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters, splits, bins)
+      for (tmp <- splits_stats_for_level) {
+        println("final best split = " + tmp._1)
+      }
       // TODO: update filters and decision tree model
+      require(scala.math.pow(2, level) == splits_stats_for_level.length)
+
     }
 
     return new DecisionTreeModel()
@@ -77,7 +83,7 @@ object DecisionTree extends Serializable {
       level: Int,
       filters: Array[List[Filter]],
       splits: Array[Array[Split]],
-      bins: Array[Array[Bin]]): Array[Split] = {
+      bins: Array[Array[Bin]]): Array[(Split, Double, Long, Long)] = {
 
     // Common calculations for multiple nested methods
     val numNodes = scala.math.pow(2, level).toInt
@@ -94,8 +100,9 @@ object DecisionTree extends Serializable {
         List[Filter]()
       } else {
         val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex
-        val parentFilterIndex = nodeFilterIndex / 2
-        filters(parentFilterIndex)
+        // val parentFilterIndex = nodeFilterIndex / 2
+        // TODO: Check left or right filter
+        filters(nodeFilterIndex)
       }
     }
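
A note on the indexing in this hunk: with level-wise growth, node nodeIndex at depth level occupies the global slot 2^level + nodeIndex, so nodes are laid out like a binary heap and integer division by 2 walks up to the parent, which is what the now-commented parentFilterIndex computed. A minimal sketch of that arithmetic, with illustrative helper names not taken from the patch:

    // Heap-style layout: node i at depth d lives at global index 2^d + i.
    object NodeIndexing {
      def globalIndex(level: Int, nodeIndex: Int): Int =
        scala.math.pow(2, level).toInt + nodeIndex

      // The parent of global index g sits at g / 2 (integer division).
      def parentIndex(global: Int): Int = global / 2
    }

    // globalIndex(2, 3) == 7, and parentIndex(7) == 3, i.e. node 1 at level 1.
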
@@ -230,30 +237,34 @@ object DecisionTree extends Serializable {
     // binAggregates.foreach(x => println(x))
 
-    def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = {
+    def calculateGainForSplit(leftNodeAgg: Array[Array[Double]],
+        featureIndex: Int,
+        index: Int,
+        rightNodeAgg: Array[Array[Double]],
+        topImpurity: Double): (Double, Long, Long) = {
 
       val left0Count = leftNodeAgg(featureIndex)(2 * index)
       val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
       val leftCount = left0Count + left1Count
 
-      if (leftCount == 0) return 0
-
-      // println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
-      val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
-
       val right0Count = rightNodeAgg(featureIndex)(2 * index)
       val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
       val rightCount = right0Count + right1Count
 
-      if (rightCount == 0) return 0
+      if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong)
+
+      // println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
+      val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
+
+      if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong)
 
       // println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
       val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
 
       val leftWeight = leftCount.toDouble / (leftCount + rightCount)
       val rightWeight = rightCount.toDouble / (leftCount + rightCount)
 
-      topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+      (topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong)
 
     }
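
For reference, the value returned by calculateGainForSplit is the information gain of the candidate split: the parent node's impurity minus the count-weighted impurities of the left and right children. A minimal standalone sketch using Gini impurity for binary labels (the patch passes the parent impurity in as topImpurity; here it is recomputed from the combined counts for self-containment):

    object GainSketch {
      // Gini impurity of a node with c0 examples of class 0 and c1 of class 1.
      def gini(c0: Double, c1: Double): Double = {
        val total = c0 + c1
        val p0 = c0 / total
        val p1 = c1 / total
        1.0 - p0 * p0 - p1 * p1
      }

      // gain = parent impurity - weighted child impurities
      def gain(l0: Double, l1: Double, r0: Double, r1: Double): Double = {
        val left = l0 + l1
        val right = r0 + r1
        val parent = gini(l0 + r0, l1 + r1)
        parent - (left / (left + right)) * gini(l0, l1) -
          (right / (left + right)) * gini(r0, r1)
      }
    }

    // GainSketch.gain(40, 10, 10, 40) is approximately 0.18: a 50/50 parent
    // (impurity 0.5) split into two mostly pure children (impurity 0.32 each).
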
@@ -295,9 +306,10 @@ object DecisionTree extends Serializable {
       (leftNodeAgg, rightNodeAgg)
     }
 
-    def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double): Array[Array[Double]] = {
+    def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double)
+      : Array[Array[(Double, Long, Long)]] = {
 
-      val gains = Array.ofDim[Double](numFeatures, numSplits - 1)
+      val gains = Array.ofDim[(Double, Long, Long)](numFeatures, numSplits - 1)
 
       for (featureIndex <- 0 until numFeatures) {
         for (index <- 0 until numSplits - 1) {
@@ -313,40 +325,44 @@ object DecisionTree extends Serializable {
 
       @param binData Array[Double] of size 2*numSplits*numFeatures
       */
-    def binsToBestSplit(binData: Array[Double], nodeImpurity: Double): Split = {
+    def binsToBestSplit(binData: Array[Double], nodeImpurity: Double): (Split, Double, Long, Long) = {
       println("node impurity = " + nodeImpurity)
       val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
       val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
 
       // println("gains.size = " + gains.size)
       // println("gains(0).size = " + gains(0).size)
 
-      val (bestFeatureIndex, bestSplitIndex) = {
+      val (bestFeatureIndex, bestSplitIndex, gain, leftCount, rightCount) = {
         var bestFeatureIndex = 0
         var bestSplitIndex = 0
         var maxGain = Double.MinValue
+        var leftSamples = Long.MinValue
+        var rightSamples = Long.MinValue
         for (featureIndex <- 0 until numFeatures) {
           for (splitIndex <- 0 until numSplits - 1) {
             val gain = gains(featureIndex)(splitIndex)
             // println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
-            if (gain > maxGain) {
-              maxGain = gain
+            if (gain._1 > maxGain) {
+              maxGain = gain._1
+              leftSamples = gain._2
+              rightSamples = gain._3
               bestFeatureIndex = featureIndex
               bestSplitIndex = splitIndex
-              println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + ", maxGain = " + maxGain)
+              println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex
+                + ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ", rightSamples = " + rightSamples)
             }
           }
         }
-        (bestFeatureIndex, bestSplitIndex)
+        (bestFeatureIndex, bestSplitIndex, maxGain, leftSamples, rightSamples)
       }
 
-      splits(bestFeatureIndex)(bestSplitIndex)
-
-      // TODo: Return array of node stats with split and impurity information
+      (splits(bestFeatureIndex)(bestSplitIndex), gain, leftCount, rightCount)
+      // TODO: Return array of node stats with split and impurity information
     }
 
     // Calculate best splits for all nodes at a given level
-    val bestSplits = new Array[Split](numNodes)
+    val bestSplits = new Array[(Split, Double, Long, Long)](numNodes)
     for (node <- 0 until numNodes) {
       val shift = 2 * node * numSplits * numFeatures
       val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures)
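
The selection in binsToBestSplit is a plain argmax over the gains matrix with extra bookkeeping for the sample counts. The same result, sketched without mutable state (illustrative only, not part of the patch):

    // Pick the (featureIndex, splitIndex) whose (gain, leftCount, rightCount)
    // tuple carries the highest gain.
    def bestOf(gains: Array[Array[(Double, Long, Long)]]): (Int, Int, (Double, Long, Long)) = {
      val indexed = for {
        f <- gains.indices
        s <- gains(f).indices
      } yield (f, s, gains(f)(s))
      indexed.maxBy(_._3._1)
    }
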
@@ -381,9 +397,6 @@ object DecisionTree extends Serializable {
     val sampledInput = input.sample(false, fraction, 42).collect()
     val numSamples = sampledInput.length
 
-    // TODO: Remove this requirement
-    require(numSamples > numSplits, "length of input samples should be greater than numSplits")
-
     // Find the number of features by looking at the first sample
     val numFeatures = input.take(1)(0).features.length
@@ -395,14 +408,22 @@ object DecisionTree extends Serializable {
     // Find all splits
     for (featureIndex <- 0 until numFeatures) {
       val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
-      val stride: Double = numSamples.toDouble / numSplits
-
-      println("stride = " + stride)
 
-      for (index <- 0 until numSplits - 1) {
-        val sampleIndex = (index + 1) * stride.toInt
-        val split = new Split(featureIndex, featureSamples(sampleIndex), "continuous")
-        splits(featureIndex)(index) = split
+      if (numSamples < numSplits) {
+        // TODO: Test this
+        println("numSamples = " + numSamples + ", less than numSplits = " + numSplits)
+        for (index <- 0 until numSplits - 1) {
+          val split = new Split(featureIndex, featureSamples(index), "continuous")
+          splits(featureIndex)(index) = split
+        }
+      } else {
+        val stride: Double = numSamples.toDouble / numSplits
+        println("stride = " + stride)
+        for (index <- 0 until numSplits - 1) {
+          val sampleIndex = (index + 1) * stride.toInt
+          val split = new Split(featureIndex, featureSamples(sampleIndex), "continuous")
+          splits(featureIndex)(index) = split
+        }
       }
     }
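
The else branch here places candidate thresholds at roughly equally spaced positions in the sorted sample (equi-depth quantiles), while the new branch falls back to the sampled values themselves when there are fewer samples than requested splits. One subtlety: the patch truncates stride with stride.toInt before multiplying, so for non-integer strides the chosen positions drift below the exact quantiles. A minimal sketch of the intended quantile placement (names are illustrative):

    object QuantileSplits {
      // Thresholds at the interior numSplits-quantile boundaries of a sorted sample.
      def candidateThresholds(sorted: Array[Double], numSplits: Int): Array[Double] = {
        val stride = sorted.length.toDouble / numSplits
        (1 until numSplits).map(i => sorted((i * stride).toInt)).toArray
      }
    }

    // For a sorted sample of 1.0 to 100.0 and numSplits = 4, this yields
    // thresholds 26.0, 51.0, 76.0, near the 25th/50th/75th percentiles.
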
@@ -430,4 +451,36 @@ object DecisionTree extends Serializable {
     }
   }
 
+  def main(args: Array[String]) {
+
+    val sc = new SparkContext(args(0), "DecisionTree")
+    val data = loadLabeledData(sc, args(1))
+
+    val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569)
+    val model = new DecisionTree(strategy).train(data)
+
+    sc.stop()
+  }
+
+  /**
+   * Load labeled data from a file. The data format used here is
+   * <L>,<f1>,<f2>,...
+   * where <f1>, <f2>, ... are feature values in Double and <L> is the corresponding label as Double.
+   *
+   * @param sc SparkContext
+   * @param dir Directory to the input data files.
+   * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
+   *         the label, and the second element represents the feature values (an array of Double).
+   */
+  def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
+    sc.textFile(dir).map { line =>
+      val parts = line.trim().split(",")
+      val label = parts(0).toDouble
+      val features = parts.slice(1, parts.length).map(_.toDouble)
+      LabeledPoint(label, features)
+    }
+  }
+
 }
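
For completeness, a sketch of how the new entry point can be exercised; the master URL and input path are placeholders, and each input line follows the loadLabeledData format (label first, then comma-separated features):

    // Input file, e.g. /tmp/labeled_points.csv:
    //   1.0,5.1,3.5
    //   0.0,4.9,3.0
    object TrainExample {
      def main(args: Array[String]): Unit = {
        // Runs the DecisionTree driver added above: args are (master, data path).
        org.apache.spark.mllib.tree.DecisionTree.main(
          Array("local[2]", "/tmp/labeled_points.csv"))
      }
    }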