Skip to content

Commit 03f534c

Browse files
committed
some more tests
Signed-off-by: Manish Amde <manish9ue@gmail.com>
1 parent 0012a77 commit 03f534c

File tree

2 files changed

+102
-23
lines changed

2 files changed

+102
-23
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ object DecisionTree extends Serializable {
121121

122122
/*Finds the right bin for the given feature*/
123123
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
124-
println("finding bin for labeled point " + labeledPoint.features(featureIndex))
124+
//println("finding bin for labeled point " + labeledPoint.features(featureIndex))
125125
//TODO: Do binary search
126126
for (binIndex <- 0 until strategy.numSplits) {
127127
val bin = bins(featureIndex)(binIndex)
@@ -227,21 +227,27 @@ object DecisionTree extends Serializable {
227227

228228
val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)
229229
println("binAggregates.length = " + binAggregates.length)
230-
binAggregates.foreach(x => println(x))
230+
//binAggregates.foreach(x => println(x))
231231

232232

233233
def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = {
234234

235235
val left0Count = leftNodeAgg(featureIndex)(2 * index)
236236
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
237237
val leftCount = left0Count + left1Count
238-
println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
238+
239+
if (leftCount == 0) return 0
240+
241+
//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
239242
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
240243

241244
val right0Count = rightNodeAgg(featureIndex)(2 * index)
242245
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
243246
val rightCount = right0Count + right1Count
244-
println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
247+
248+
if (rightCount == 0) return 0
249+
250+
//println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
245251
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
246252

247253
val leftWeight = leftCount.toDouble / (leftCount + rightCount)
@@ -261,21 +267,21 @@ object DecisionTree extends Serializable {
261267
def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
262268
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
263269
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
264-
println("binData.length = " + binData.length)
265-
println("binData.sum = " + binData.sum)
270+
//println("binData.length = " + binData.length)
271+
//println("binData.sum = " + binData.sum)
266272
for (featureIndex <- 0 until numFeatures) {
267-
println("featureIndex = " + featureIndex)
273+
//println("featureIndex = " + featureIndex)
268274
val shift = 2*featureIndex*numSplits
269275
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
270-
println("binData(shift + 0) = " + binData(shift + 0))
276+
//println("binData(shift + 0) = " + binData(shift + 0))
271277
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
272-
println("binData(shift + 1) = " + binData(shift + 1))
278+
//println("binData(shift + 1) = " + binData(shift + 1))
273279
rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
274-
println(binData(shift + (2 * (numSplits - 1))))
280+
//println(binData(shift + (2 * (numSplits - 1))))
275281
rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
276-
println(binData(shift + (2 * (numSplits - 1)) + 1))
282+
//println(binData(shift + (2 * (numSplits - 1)) + 1))
277283
for (splitIndex <- 1 until numSplits - 1) {
278-
println("splitIndex = " + splitIndex)
284+
//println("splitIndex = " + splitIndex)
279285
leftNodeAgg(featureIndex)(2 * splitIndex)
280286
= binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
281287
leftNodeAgg(featureIndex)(2 * splitIndex + 1)
@@ -295,7 +301,7 @@ object DecisionTree extends Serializable {
295301

296302
for (featureIndex <- 0 until numFeatures) {
297303
for (index <- 0 until numSplits -1) {
298-
println("splitIndex = " + index)
304+
//println("splitIndex = " + index)
299305
gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
300306
}
301307
}
@@ -312,8 +318,8 @@ object DecisionTree extends Serializable {
312318
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
313319
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
314320

315-
println("gains.size = " + gains.size)
316-
println("gains(0).size = " + gains(0).size)
321+
//println("gains.size = " + gains.size)
322+
//println("gains(0).size = " + gains(0).size)
317323

318324
val (bestFeatureIndex,bestSplitIndex) = {
319325
var bestFeatureIndex = 0
@@ -322,7 +328,7 @@ object DecisionTree extends Serializable {
322328
for (featureIndex <- 0 until numFeatures) {
323329
for (splitIndex <- 0 until numSplits - 1){
324330
val gain = gains(featureIndex)(splitIndex)
325-
println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
331+
//println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
326332
if(gain > maxGain) {
327333
maxGain = gain
328334
bestFeatureIndex = featureIndex
@@ -335,6 +341,8 @@ object DecisionTree extends Serializable {
335341
}
336342

337343
splits(bestFeatureIndex)(bestSplitIndex)
344+
345+
//TODO: Return array of node stats with split and impurity information
338346
}
339347

340348
//Calculate best splits for all nodes at a given level
@@ -388,6 +396,9 @@ object DecisionTree extends Serializable {
388396
for (featureIndex <- 0 until numFeatures){
389397
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
390398
val stride : Double = numSamples.toDouble/numSplits
399+
400+
println("stride = " + stride)
401+
391402
for (index <- 0 until numSplits-1) {
392403
val sampleIndex = (index+1)*stride.toInt
393404
val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.SparkContext._
2727
import org.jblas._
2828
import org.apache.spark.rdd.RDD
2929
import org.apache.spark.mllib.regression.LabeledPoint
30-
import org.apache.spark.mllib.tree.impurity.Gini
30+
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
3131
import org.apache.spark.mllib.tree.model.Filter
3232

3333
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
@@ -44,7 +44,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
4444
}
4545

4646
test("split and bin calculation"){
47-
val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
47+
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
4848
assert(arr.length == 1000)
4949
val rdd = sc.parallelize(arr)
5050
val strategy = new Strategy("regression",Gini,3,100,"sort")
@@ -56,8 +56,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
5656
println(splits(1)(98))
5757
}
5858

59-
test("stump"){
60-
val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
59+
test("stump with fixed label 0 for Gini"){
60+
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
6161
assert(arr.length == 1000)
6262
val rdd = sc.parallelize(arr)
6363
val strategy = new Strategy("regression",Gini,3,100,"sort")
@@ -69,17 +69,85 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
6969
assert(splits(0).length==99)
7070
assert(bins(0).length==100)
7171
println(splits(1)(98))
72-
DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
72+
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
73+
assert(bestSplits.length == 1)
74+
println(bestSplits(0))
7375
}
7476

77+
test("stump with fixed label 1 for Gini"){
78+
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
79+
assert(arr.length == 1000)
80+
val rdd = sc.parallelize(arr)
81+
val strategy = new Strategy("regression",Gini,3,100,"sort")
82+
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
83+
assert(splits.length==2)
84+
assert(splits(0).length==99)
85+
assert(bins.length==2)
86+
assert(bins(0).length==100)
87+
assert(splits(0).length==99)
88+
assert(bins(0).length==100)
89+
println(splits(1)(98))
90+
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
91+
assert(bestSplits.length == 1)
92+
println(bestSplits(0))
93+
}
94+
95+
96+
test("stump with fixed label 0 for Entropy"){
97+
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
98+
assert(arr.length == 1000)
99+
val rdd = sc.parallelize(arr)
100+
val strategy = new Strategy("regression",Entropy,3,100,"sort")
101+
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
102+
assert(splits.length==2)
103+
assert(splits(0).length==99)
104+
assert(bins.length==2)
105+
assert(bins(0).length==100)
106+
assert(splits(0).length==99)
107+
assert(bins(0).length==100)
108+
println(splits(1)(98))
109+
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
110+
assert(bestSplits.length == 1)
111+
println(bestSplits(0))
112+
}
113+
114+
test("stump with fixed label 1 for Entropy"){
115+
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
116+
assert(arr.length == 1000)
117+
val rdd = sc.parallelize(arr)
118+
val strategy = new Strategy("regression",Entropy,3,100,"sort")
119+
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
120+
assert(splits.length==2)
121+
assert(splits(0).length==99)
122+
assert(bins.length==2)
123+
assert(bins(0).length==100)
124+
assert(splits(0).length==99)
125+
assert(bins(0).length==100)
126+
println(splits(1)(98))
127+
val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
128+
assert(bestSplits.length == 1)
129+
println(bestSplits(0))
130+
}
131+
132+
75133
}
76134

77135
object DecisionTreeSuite {
78136

79-
def generateReverseOrderedLabeledPoints() : Array[LabeledPoint] = {
137+
def generateOrderedLabeledPointsWithLabel0() : Array[LabeledPoint] = {
138+
val arr = new Array[LabeledPoint](1000)
139+
for (i <- 0 until 1000){
140+
val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i))
141+
arr(i) = lp
142+
}
143+
arr
144+
}
145+
146+
147+
def generateOrderedLabeledPointsWithLabel1() : Array[LabeledPoint] = {
80148
val arr = new Array[LabeledPoint](1000)
81149
for (i <- 0 until 1000){
82-
val lp = new LabeledPoint(1.0,Array(i.toDouble,1000.0-i))
150+
val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i))
83151
arr(i) = lp
84152
}
85153
arr

0 commit comments

Comments
 (0)