Skip to content

Commit c0e522b

Browse files
committed
updated predict and split threshold logic
Signed-off-by: Manish Amde <manish9ue@gmail.com>
1 parent b09dc98 commit c0e522b

File tree

6 files changed

+15
-13
lines changed

6 files changed

+15
-13
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,7 @@ object DecisionTree extends Serializable with Logging {
211211
val lowThreshold = bin.lowSplit.threshold
212212
val highThreshold = bin.highSplit.threshold
213213
val features = labeledPoint.features
214-
if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
214+
if ((lowThreshold < features(featureIndex)) & (highThreshold >= features(featureIndex))) {
215215
return binIndex
216216
}
217217
}
@@ -400,7 +400,8 @@ object DecisionTree extends Serializable with Logging {
400400
}
401401
}
402402

403-
val predict = leftCount / (leftCount + rightCount)
403+
//val predict = leftCount / (leftCount + rightCount)
404+
val predict = (left1Count + right1Count) / (leftCount + rightCount)
404405

405406
new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict)
406407
}
@@ -672,8 +673,8 @@ object DecisionTree extends Serializable with Logging {
672673

673674
//Find all bins
674675
for (featureIndex <- 0 until numFeatures){
675-
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
676-
if (isFeatureContinous) { //bins for categorical variables are already assigned
676+
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
677+
if (isFeatureContinuous) { //bins for categorical variables are already assigned
677678
bins(featureIndex)(0)
678679
= new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue)
679680
for (index <- 1 until numBins - 1){

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,6 @@ object DecisionTreeRunner extends Logging {
133133
//TODO: Make these generic MLTable metrics
134134
def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = {
135135
val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean()
136-
println("meanSumOfSquares = " + meanSumOfSquares)
137136
meanSumOfSquares
138137
}
139138

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl
2424
def predict(features : Array[Double]) = {
2525
algo match {
2626
case Classification => {
27-
if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
27+
if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
2828
}
2929
case Regression => {
3030
topNode.predictIfLeaf(features)

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,10 @@ class InformationGainStats(val gain : Double,
2424
//val rightSamples : Long
2525
val predict : Double) extends Serializable {
2626

27-
override def toString =
28-
"gain = " + gain + ", impurity = " + impurity + ", left impurity = "
29-
+ leftImpurity + ", right impurity = " + rightImpurity + ", predict = " + predict
27+
override def toString = {
28+
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
29+
.format(gain, impurity, leftImpurity, rightImpurity, predict)
30+
}
3031

3132

3233
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ class Node ( val id : Int,
3434
def build(nodes : Array[Node]) : Unit = {
3535

3636
logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt )
37+
logDebug("id = " + id + ", split = " + split)
3738
logDebug("stats = " + stats)
3839
logDebug("predict = " + predict)
3940
if (!isLeaf) {

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
157157
assert(0==bestSplits(0)._2.gain)
158158
assert(0==bestSplits(0)._2.leftImpurity)
159159
assert(0==bestSplits(0)._2.rightImpurity)
160-
assert(0.01==bestSplits(0)._2.predict)
160+
println(bestSplits(0)._2.predict)
161161
}
162162

163163
test("stump with fixed label 1 for Gini"){
@@ -181,7 +181,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
181181
assert(0==bestSplits(0)._2.gain)
182182
assert(0==bestSplits(0)._2.leftImpurity)
183183
assert(0==bestSplits(0)._2.rightImpurity)
184-
assert(0.01==bestSplits(0)._2.predict)
184+
assert(1==bestSplits(0)._2.predict)
185185

186186
}
187187

@@ -207,7 +207,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
207207
assert(0==bestSplits(0)._2.gain)
208208
assert(0==bestSplits(0)._2.leftImpurity)
209209
assert(0==bestSplits(0)._2.rightImpurity)
210-
assert(0.01==bestSplits(0)._2.predict)
210+
assert(0==bestSplits(0)._2.predict)
211211
}
212212

213213
test("stump with fixed label 1 for Entropy"){
@@ -231,7 +231,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
231231
assert(0==bestSplits(0)._2.gain)
232232
assert(0==bestSplits(0)._2.leftImpurity)
233233
assert(0==bestSplits(0)._2.rightImpurity)
234-
assert(0.01==bestSplits(0)._2.predict)
234+
assert(1==bestSplits(0)._2.predict)
235235
}
236236

237237

0 commit comments

Comments
 (0)