Skip to content

Commit dad0afc

Browse files
committed
decision stump functionality working
Signed-off-by: Manish Amde <manish9ue@gmail.com>
1 parent 03f534c commit dad0afc

File tree

2 files changed

+108
-43
lines changed

2 files changed

+108
-43
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 88 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ package org.apache.spark.mllib.tree
2020
import org.apache.spark.SparkContext._
2121
import org.apache.spark.rdd.RDD
2222
import org.apache.spark.mllib.tree.model._
23-
import org.apache.spark.Logging
23+
import org.apache.spark.{SparkContext, Logging}
2424
import org.apache.spark.mllib.regression.LabeledPoint
2525
import org.apache.spark.mllib.tree.model.Split
26+
import org.apache.spark.mllib.tree.impurity.Gini
2627

2728

2829
class DecisionTree(val strategy : Strategy) {
@@ -46,8 +47,13 @@ class DecisionTree(val strategy : Strategy) {
4647
//Find best split for all nodes at a level
4748
val numNodes= scala.math.pow(2,level).toInt
4849
//TODO: Change the input parent impurities values
49-
val bestSplits = DecisionTree.findBestSplits(input, Array(0.0), strategy, level, filters,splits,bins)
50+
val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters,splits,bins)
51+
for (tmp <- splits_stats_for_level){
52+
println("final best split = " + tmp._1)
53+
}
5054
//TODO: update filters and decision tree model
55+
require(scala.math.pow(2,level)==splits_stats_for_level.length)
56+
5157
}
5258

5359
return new DecisionTreeModel()
@@ -77,7 +83,7 @@ object DecisionTree extends Serializable {
7783
level: Int,
7884
filters : Array[List[Filter]],
7985
splits : Array[Array[Split]],
80-
bins : Array[Array[Bin]]) : Array[Split] = {
86+
bins : Array[Array[Bin]]) : Array[(Split, Double, Long, Long)] = {
8187

8288
//Common calculations for multiple nested methods
8389
val numNodes = scala.math.pow(2, level).toInt
@@ -94,8 +100,9 @@ object DecisionTree extends Serializable {
94100
List[Filter]()
95101
} else {
96102
val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex
97-
val parentFilterIndex = nodeFilterIndex / 2
98-
filters(parentFilterIndex)
103+
//val parentFilterIndex = nodeFilterIndex / 2
104+
//TODO: Check left or right filter
105+
filters(nodeFilterIndex)
99106
}
100107
}
101108

@@ -230,30 +237,34 @@ object DecisionTree extends Serializable {
230237
//binAggregates.foreach(x => println(x))
231238

232239

233-
def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = {
240+
def calculateGainForSplit(leftNodeAgg: Array[Array[Double]],
241+
featureIndex: Int,
242+
index: Int,
243+
rightNodeAgg: Array[Array[Double]],
244+
topImpurity: Double) : (Double, Long, Long) = {
234245

235246
val left0Count = leftNodeAgg(featureIndex)(2 * index)
236247
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
237248
val leftCount = left0Count + left1Count
238249

239-
if (leftCount == 0) return 0
240-
241-
//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
242-
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
243-
244250
val right0Count = rightNodeAgg(featureIndex)(2 * index)
245251
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
246252
val rightCount = right0Count + right1Count
247253

248-
if (rightCount == 0) return 0
254+
if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong)
255+
256+
//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
257+
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
258+
259+
if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong)
249260

250261
//println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
251262
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
252263

253264
val leftWeight = leftCount.toDouble / (leftCount + rightCount)
254265
val rightWeight = rightCount.toDouble / (leftCount + rightCount)
255266

256-
topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
267+
(topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong)
257268

258269
}
259270

@@ -295,9 +306,10 @@ object DecisionTree extends Serializable {
295306
(leftNodeAgg, rightNodeAgg)
296307
}
297308

298-
def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double): Array[Array[Double]] = {
309+
def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double)
310+
: Array[Array[(Double,Long,Long)]] = {
299311

300-
val gains = Array.ofDim[Double](numFeatures, numSplits - 1)
312+
val gains = Array.ofDim[(Double,Long,Long)](numFeatures, numSplits - 1)
301313

302314
for (featureIndex <- 0 until numFeatures) {
303315
for (index <- 0 until numSplits -1) {
@@ -313,40 +325,44 @@ object DecisionTree extends Serializable {
313325
314326
@param binData Array[Double] of size 2*numSplits*numFeatures
315327
*/
316-
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : Split = {
328+
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, Double, Long, Long) = {
317329
println("node impurity = " + nodeImpurity)
318330
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
319331
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
320332

321333
//println("gains.size = " + gains.size)
322334
//println("gains(0).size = " + gains(0).size)
323335

324-
val (bestFeatureIndex,bestSplitIndex) = {
336+
val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount) = {
325337
var bestFeatureIndex = 0
326338
var bestSplitIndex = 0
327339
var maxGain = Double.MinValue
340+
var leftSamples = Long.MinValue
341+
var rightSamples = Long.MinValue
328342
for (featureIndex <- 0 until numFeatures) {
329343
for (splitIndex <- 0 until numSplits - 1){
330344
val gain = gains(featureIndex)(splitIndex)
331345
//println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
332-
if(gain > maxGain) {
333-
maxGain = gain
346+
if(gain._1 > maxGain) {
347+
maxGain = gain._1
348+
leftSamples = gain._2
349+
rightSamples = gain._3
334350
bestFeatureIndex = featureIndex
335351
bestSplitIndex = splitIndex
336-
println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + ", maxGain = " + maxGain)
352+
println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex
353+
+ ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ",rightSamples = " + rightSamples)
337354
}
338355
}
339356
}
340-
(bestFeatureIndex,bestSplitIndex)
357+
(bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples)
341358
}
342359

343-
splits(bestFeatureIndex)(bestSplitIndex)
344-
345-
//TODo: Return array of node stats with split and impurity information
360+
(splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount)
361+
//TODO: Return array of node stats with split and impurity information
346362
}
347363

348364
//Calculate best splits for all nodes at a given level
349-
val bestSplits = new Array[Split](numNodes)
365+
val bestSplits = new Array[(Split, Double, Long, Long)](numNodes)
350366
for (node <- 0 until numNodes){
351367
val shift = 2*node*numSplits*numFeatures
352368
val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures)
@@ -381,9 +397,6 @@ object DecisionTree extends Serializable {
381397
val sampledInput = input.sample(false, fraction, 42).collect()
382398
val numSamples = sampledInput.length
383399

384-
//TODO: Remove this requirement
385-
require(numSamples > numSplits, "length of input samples should be greater than numSplits")
386-
387400
//Find the number of features by looking at the first sample
388401
val numFeatures = input.take(1)(0).features.length
389402

@@ -395,14 +408,22 @@ object DecisionTree extends Serializable {
395408
//Find all splits
396409
for (featureIndex <- 0 until numFeatures){
397410
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
398-
val stride : Double = numSamples.toDouble/numSplits
399-
400-
println("stride = " + stride)
401411

402-
for (index <- 0 until numSplits-1) {
403-
val sampleIndex = (index+1)*stride.toInt
404-
val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
405-
splits(featureIndex)(index) = split
412+
if (numSamples < numSplits) {
413+
//TODO: Test this
414+
println("numSamples = " + numSamples + ", less than numSplits = " + numSplits)
415+
for (index <- 0 until numSplits-1) {
416+
val split = new Split(featureIndex,featureSamples(index),"continuous")
417+
splits(featureIndex)(index) = split
418+
}
419+
} else {
420+
val stride : Double = numSamples.toDouble/numSplits
421+
println("stride = " + stride)
422+
for (index <- 0 until numSplits-1) {
423+
val sampleIndex = (index+1)*stride.toInt
424+
val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
425+
splits(featureIndex)(index) = split
426+
}
406427
}
407428
}
408429

@@ -430,4 +451,36 @@ object DecisionTree extends Serializable {
430451
}
431452
}
432453

454+
def main(args: Array[String]) {
455+
456+
val sc = new SparkContext(args(0), "DecisionTree")
457+
val data = loadLabeledData(sc, args(1))
458+
459+
val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569)
460+
val model = new DecisionTree(strategy).train(data)
461+
462+
sc.stop()
463+
}
464+
465+
/**
466+
* Load labeled data from a file. The data format used here is
467+
* <L>, <f1> <f2> ...
468+
* where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
469+
*
470+
* @param sc SparkContext
471+
* @param dir Directory to the input data files.
472+
* @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
473+
* the label, and the second element represents the feature values (an array of Double).
474+
*/
475+
def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
476+
sc.textFile(dir).map { line =>
477+
val parts = line.trim().split(",")
478+
val label = parts(0).toDouble
479+
val features = parts.slice(1,parts.length).map(_.toDouble)
480+
LabeledPoint(label, features)
481+
}
482+
}
483+
484+
485+
433486
}

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
6868
assert(bins(0).length==100)
6969
assert(splits(0).length==99)
7070
assert(bins(0).length==100)
71-
println(splits(1)(98))
7271
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
7372
assert(bestSplits.length == 1)
74-
println(bestSplits(0))
73+
assert(0==bestSplits(0)._1.feature)
74+
assert(10==bestSplits(0)._1.threshold)
75+
assert(0==bestSplits(0)._2)
76+
assert(10==bestSplits(0)._3)
77+
assert(990==bestSplits(0)._4)
7578
}
7679

7780
test("stump with fixed label 1 for Gini"){
@@ -86,10 +89,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
8689
assert(bins(0).length==100)
8790
assert(splits(0).length==99)
8891
assert(bins(0).length==100)
89-
println(splits(1)(98))
9092
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
9193
assert(bestSplits.length == 1)
92-
println(bestSplits(0))
94+
assert(0==bestSplits(0)._1.feature)
95+
assert(10==bestSplits(0)._1.threshold)
96+
assert(0==bestSplits(0)._2)
97+
assert(10==bestSplits(0)._3)
98+
assert(990==bestSplits(0)._4)
9399
}
94100

95101

@@ -105,10 +111,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
105111
assert(bins(0).length==100)
106112
assert(splits(0).length==99)
107113
assert(bins(0).length==100)
108-
println(splits(1)(98))
109114
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
110115
assert(bestSplits.length == 1)
111-
println(bestSplits(0))
116+
assert(0==bestSplits(0)._1.feature)
117+
assert(10==bestSplits(0)._1.threshold)
118+
assert(0==bestSplits(0)._2)
119+
assert(10==bestSplits(0)._3)
120+
assert(990==bestSplits(0)._4)
112121
}
113122

114123
test("stump with fixed label 1 for Entropy"){
@@ -123,10 +132,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
123132
assert(bins(0).length==100)
124133
assert(splits(0).length==99)
125134
assert(bins(0).length==100)
126-
println(splits(1)(98))
127135
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
128136
assert(bestSplits.length == 1)
129-
println(bestSplits(0))
137+
assert(0==bestSplits(0)._1.feature)
138+
assert(10==bestSplits(0)._1.threshold)
139+
assert(0==bestSplits(0)._2)
140+
assert(10==bestSplits(0)._3)
141+
assert(990==bestSplits(0)._4)
130142
}
131143

132144

0 commit comments

Comments
 (0)