Skip to content

Commit d7c53ee

Browse files
committed
Added more doc for ImpurityAggregator
1 parent a40f8f1 commit d7c53ee

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ object DecisionTree extends Serializable with Logging {
954954
featureIndex += 1
955955
}
956956
} else { // Regression
957-
var featureIndex = 0
957+
var featureIndex = 0
958958
while (featureIndex < numFeatures) {
959959
findAggForOrderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex)
960960
featureIndex += 1

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,35 @@ trait Impurity extends Serializable {
4848
def calculate(count: Double, sum: Double, sumSquares: Double): Double
4949
}
5050

51-
51+
/**
52+
* This class holds a set of sufficient statistics for computing impurity from a sample.
53+
* @param statsSize Length of the vector of sufficient statistics.
54+
*/
5255
private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializable {
5356

57+
/**
58+
* Sufficient statistics for calculating impurity.
59+
*/
5460
var counts: Array[Double] = new Array[Double](statsSize)
5561

5662
def copy: ImpurityAggregator
5763

64+
/**
65+
* Add the given label to this aggregator.
66+
*/
5867
def add(label: Double): Unit
5968

69+
/**
70+
* Compute the impurity for the samples given so far.
71+
* If no samples have been collected, return 0.
72+
*/
6073
def calculate(): Double
6174

75+
/**
76+
* Merge another aggregator into this one, modifying this aggregator.
77+
* @param other Aggregator of the same type.
78+
* @return merged aggregator
79+
*/
6280
def merge(other: ImpurityAggregator): ImpurityAggregator = {
6381
require(counts.size == other.counts.size,
6482
s"Two ImpurityAggregator instances cannot be merged with different counts sizes." +
@@ -71,14 +89,31 @@ private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializ
7189
this
7290
}
7391

92+
/**
93+
* Number of samples added to this aggregator.
94+
*/
7495
def count: Long
7596

97+
/**
98+
* Create a new (empty) aggregator of the same type as this one.
99+
*/
76100
def newAggregator: ImpurityAggregator
77101

102+
/**
103+
* Return the prediction corresponding to the set of labels given to this aggregator.
104+
*/
78105
def predict: Double
79106

107+
/**
108+
* Return the probability of the prediction returned by [[predict]],
109+
* or -1 if no probability is available.
110+
*/
80111
def prob(label: Double): Double = -1
81112

113+
/**
114+
* Return the index of the largest element in this array.
115+
* If there are ties, the first maximal element is chosen.
116+
*/
82117
protected def indexOfLargestArrayElement(array: Array[Double]): Int = {
83118
val result = array.foldLeft(-1, Double.MinValue, 0) {
84119
case ((maxIndex, maxValue, currentIndex), currentValue) =>

0 commit comments

Comments
 (0)