Skip to content

Commit 8bca1e2

Browse files
committed
additional code for creating intermediate RDD
Signed-off-by: Manish Amde <manish9ue@gmail.com>
1 parent 92cedce commit 8bca1e2

File tree

4 files changed

+120
-26
lines changed

4 files changed

+120
-26
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 100 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ class DecisionTree(val strategy : Strategy) {
3737
val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)
3838

3939
//TODO: Level-wise training of tree and obtain Decision Tree model
40-
4140
val maxDepth = strategy.maxDepth
4241

4342
val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1
@@ -55,8 +54,20 @@ class DecisionTree(val strategy : Strategy) {
5554

5655
}
5756

58-
object DecisionTree extends Logging {
57+
object DecisionTree extends Serializable {
58+
59+
/*
60+
Returns an Array[Split] of optimal splits for all nodes at a given level
61+
62+
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
63+
@param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
64+
@param level Level of the tree
65+
@param filters Filter for all nodes at a given level
66+
@param splits possible splits for all features
67+
@param bins possible bins for all features
5968
69+
@return Array[Split] instance for best splits for all nodes at a given level.
70+
*/
6071
def findBestSplits(
6172
input : RDD[LabeledPoint],
6273
strategy: Strategy,
@@ -65,6 +76,16 @@ object DecisionTree extends Logging {
6576
splits : Array[Array[Split]],
6677
bins : Array[Array[Bin]]) : Array[Split] = {
6778

79+
//TODO: Move these calculations outside
80+
val numNodes = scala.math.pow(2, level).toInt
81+
println("numNodes = " + numNodes)
82+
//Find the number of features by looking at the first sample
83+
val numFeatures = input.take(1)(0).features.length
84+
println("numFeatures = " + numFeatures)
85+
val numSplits = strategy.numSplits
86+
println("numSplits = " + numSplits)
87+
88+
/*Find the filters used before reaching the current code*/
6889
def findParentFilters(nodeIndex: Int): List[Filter] = {
6990
if (level == 0) {
7091
List[Filter]()
@@ -75,6 +96,10 @@ object DecisionTree extends Logging {
7596
}
7697
}
7798

99+
/*Find whether the sample is valid input for the current node.
100+
101+
In other words, does it pass through all the filters for the current node.
102+
*/
78103
def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
79104

80105
for (filter <- parentFilters) {
@@ -91,79 +116,130 @@ object DecisionTree extends Logging {
91116
true
92117
}
93118

119+
/*Finds the right bin for the given feature*/
94120
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
95-
96121
//TODO: Do binary search
97122
for (binIndex <- 0 until strategy.numSplits) {
98123
val bin = bins(featureIndex)(binIndex)
99-
//TODO: Remove this requirement post basic functional testing
100-
require(bin.lowSplit.feature == featureIndex)
101-
require(bin.highSplit.feature == featureIndex)
124+
//TODO: Remove this requirement post basic functional
102125
val lowThreshold = bin.lowSplit.threshold
103126
val highThreshold = bin.highSplit.threshold
104127
val features = labeledPoint.features
105-
if ((lowThreshold < features(featureIndex)) & (highThreshold < features(featureIndex))) {
128+
if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
106129
return binIndex
107130
}
108131
}
109132
throw new UnknownError("no bin was found.")
110133

111134
}
112-
def findBinsForLevel: Array[Double] = {
113135

114-
val numNodes = scala.math.pow(2, level).toInt
115-
//Find the number of features by looking at the first sample
116-
val numFeatures = input.take(1)(0).features.length
136+
/*Finds bins for all nodes (and all features) at a given level
137+
k features, l nodes
138+
Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk
139+
Denotes invalid sample for tree by noting bin for feature 1 as -1
140+
*/
141+
def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = {
142+
117143

118-
//TODO: Bit pack more by removing redundant label storage
119144
// calculating bin index and label per feature per node
120-
val arr = new Array[Double](2 * numFeatures * numNodes)
145+
val arr = new Array[Double](1+(numFeatures * numNodes))
146+
arr(0) = labeledPoint.label
121147
for (nodeIndex <- 0 until numNodes) {
122148
val parentFilters = findParentFilters(nodeIndex)
123149
//Find out whether the sample qualifies for the particular node
124150
val sampleValid = isSampleValid(parentFilters, labeledPoint)
125-
val shift = 2 * numFeatures * nodeIndex
126-
if (sampleValid) {
151+
val shift = 1 + numFeatures * nodeIndex
152+
if (!sampleValid) {
127153
//Add to invalid bin index -1
128-
for (featureIndex <- shift until (shift + numFeatures) by 2) {
129-
arr(featureIndex + 1) = -1
130-
arr(featureIndex + 2) = labeledPoint.label
154+
for (featureIndex <- 0 until numFeatures) {
155+
arr(shift+featureIndex) = -1
156+
//TODO: Break since marking one bin is sufficient
131157
}
132158
} else {
133159
for (featureIndex <- 0 until numFeatures) {
134-
arr(shift + (featureIndex * 2) + 1) = findBin(featureIndex, labeledPoint)
135-
arr(shift + (featureIndex * 2) + 2) = labeledPoint.label
160+
//println("shift+featureIndex =" + (shift+featureIndex))
161+
arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
136162
}
137163
}
138164

139165
}
140166
arr
141167
}
142168

143-
val binMappedRDD = input.map(labeledPoint => findBinsForLevel)
169+
/*
170+
Performs a sequential aggregation over a partition
171+
172+
@param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
173+
and 3*numSplits*numFeatures*numNodes for regression
174+
@param arr Array[Double] of size 1+(numFeatures*numNodes)
175+
@return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
176+
and 3*numSplits*numFeatures*numNodes for regression
177+
*/
178+
def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
179+
for (node <- 0 until numNodes) {
180+
val validSignalIndex = 1+numFeatures*node
181+
val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false
182+
if(isSampleValidForNode) {
183+
for (feature <- 0 until numFeatures){
184+
val arrShift = 1 + numFeatures*node
185+
val aggShift = numSplits*numFeatures*node
186+
val arrIndex = arrShift + feature
187+
val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt
188+
agg(aggIndex) = agg(aggIndex) + 1
189+
}
190+
}
191+
}
192+
agg
193+
}
194+
195+
def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = {
196+
par1
197+
}
198+
199+
println("input = " + input.count)
200+
val binMappedRDD = input.map(x => findBinsForLevel(x))
201+
println("binMappedRDD.count = " + binMappedRDD.count)
144202
//calculate bin aggregates
203+
204+
val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)
205+
145206
//find best split
207+
println("binAggregates.length = " + binAggregates.length)
146208

147209

148-
Array[Split]()
210+
val bestSplits = new Array[Split](numNodes)
211+
for (node <- 0 until numNodes){
212+
val binsForNode = binAggregates.slice(node,numSplits*node)
213+
}
214+
215+
bestSplits
149216
}
150217

218+
/*
219+
Returns split and bins for decision tree calculation.
220+
221+
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
222+
@param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
223+
@return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an
224+
Array[Array[Bin]] of size (numFeatures,numSplits)
225+
*/
151226
def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
152227

153228
val numSplits = strategy.numSplits
154-
logDebug("numSplits = " + numSplits)
229+
println("numSplits = " + numSplits)
155230

156231
//Calculate the number of sample for approximate quantile calculation
157232
//TODO: Justify this calculation
158233
val requiredSamples = numSplits*numSplits
159234
val count = input.count()
160235
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
161-
logDebug("fraction of data used for calculating quantiles = " + fraction)
236+
println("fraction of data used for calculating quantiles = " + fraction)
162237

163238
//sampled input for RDD calculation
164239
val sampledInput = input.sample(false, fraction, 42).collect()
165240
val numSamples = sampledInput.length
166241

242+
//TODO: Remove this requirement
167243
require(numSamples > numSplits, "length of input samples should be greater than numSplits")
168244

169245
//Find the number of features by looking at the first sample

mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree
1818

1919
import org.apache.spark.mllib.tree.impurity.Impurity
2020

21-
class Strategy (
21+
case class Strategy (
2222
val kind : String,
2323
val impurity : Impurity,
2424
val maxDepth : Int,

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717
package org.apache.spark.mllib.tree.impurity
1818

19-
trait Impurity {
19+
trait Impurity extends Serializable {
2020

2121
def calculate(c0 : Double, c1 : Double): Double
2222

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.jblas._
2828
import org.apache.spark.rdd.RDD
2929
import org.apache.spark.mllib.regression.LabeledPoint
3030
import org.apache.spark.mllib.tree.impurity.Gini
31+
import org.apache.spark.mllib.tree.model.Filter
3132

3233
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
3334

@@ -54,6 +55,23 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
5455
assert(bins(0).length==100)
5556
println(splits(1)(98))
5657
}
58+
59+
test("stump"){
60+
val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
61+
assert(arr.length == 1000)
62+
val rdd = sc.parallelize(arr)
63+
val strategy = new Strategy("regression",Gini,3,100,"sort")
64+
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
65+
assert(splits.length==2)
66+
assert(splits(0).length==99)
67+
assert(bins.length==2)
68+
assert(bins(0).length==100)
69+
assert(splits(0).length==99)
70+
assert(bins(0).length==100)
71+
println(splits(1)(98))
72+
DecisionTree.findBestSplits(rdd,strategy,0,Array[List[Filter]](),splits,bins)
73+
}
74+
5775
}
5876

5977
object DecisionTreeSuite {

0 commit comments

Comments
 (0)