@@ -48,17 +48,35 @@ trait Impurity extends Serializable {
48
48
def calculate (count : Double , sum : Double , sumSquares : Double ): Double
49
49
}
50
50
51
-
51
+ /**
52
+ * This class holds a set of sufficient statistics for computing impurity from a sample.
53
+ * @param statsSize Length of the vector of sufficient statistics.
54
+ */
52
55
private [tree] abstract class ImpurityAggregator (statsSize : Int ) extends Serializable {
53
56
57
+ /**
58
+ * Sufficient statistics for calculating impurity.
59
+ */
54
60
var counts : Array [Double ] = new Array [Double ](statsSize)
55
61
56
62
/**
 * Return an [[ImpurityAggregator]] derived from this one.
 * NOTE(review): by its name this is presumably a copy of this aggregator; whether
 * `counts` is duplicated or shared is decided by each subclass -- confirm there.
 */
def copy : ImpurityAggregator
57
63
64
+ /**
65
+ * Add the given label to this aggregator.
66
+ */
58
67
def add (label : Double ): Unit
59
68
69
+ /**
70
+ * Compute the impurity for the samples given so far.
71
+ * If no samples have been collected, return 0.
72
+ */
60
73
def calculate (): Double
61
74
75
+ /**
76
+ * Merge another aggregator into this one, modifying this aggregator.
77
+ * @param other Aggregator of the same type.
78
+ * @return This aggregator, after `other` has been merged into it.
79
+ */
62
80
def merge (other : ImpurityAggregator ): ImpurityAggregator = {
63
81
require(counts.size == other.counts.size,
64
82
s " Two ImpurityAggregator instances cannot be merged with different counts sizes. " +
@@ -71,14 +89,31 @@ private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializ
71
89
this
72
90
}
73
91
92
+ /**
93
+ * Number of samples added to this aggregator.
94
+ */
74
95
def count : Long
75
96
97
+ /**
98
+ * Create a new (empty) aggregator of the same type as this one.
99
+ */
76
100
def newAggregator : ImpurityAggregator
77
101
102
+ /**
103
+ * Return the prediction corresponding to the set of labels given to this aggregator.
104
+ */
78
105
def predict : Double
79
106
107
+ /**
108
+ * Return the probability of the given label,
109
+ * or -1 if no probability is available.
110
+ */
80
111
// Default implementation: always yields the sentinel -1, regardless of `label`.
def prob(label: Double): Double = -1.0
81
112
113
+ /**
114
+ * Return the index of the largest element in the given array.
115
+ * If there are ties, the first maximal element is chosen.
116
+ */
82
117
protected def indexOfLargestArrayElement (array : Array [Double ]): Int = {
83
118
val result = array.foldLeft(- 1 , Double .MinValue , 0 ) {
84
119
case ((maxIndex, maxValue, currentIndex), currentValue) =>
0 commit comments