
Commit 84f85d6

code documentation
Signed-off-by: Manish Amde <manish9ue@gmail.com>
1 parent 9372779 commit 84f85d6

14 files changed (+157, -21 lines)


mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 18 additions & 13 deletions

@@ -29,7 +29,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Algo._
 
-/*
+/**
 A class that implements a decision tree algorithm for classification and regression.
 It supports both continuous and categorical features.
 
@@ -40,7 +40,7 @@ quantile calculation strategy, etc.
 */
 class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
-/*
+/**
 Method to train a decision tree model over an RDD
 
 @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -157,14 +157,14 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
 object DecisionTree extends Serializable with Logging {
 
-/*
+/**
 Returns an Array[Split] of optimal splits for all nodes at a given level
 
 @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
-for DecisionTree
+for DecisionTree
 @param parentImpurities Impurities for all parent nodes for the current level
 @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-parameters for construction the DecisionTree
+parameters for construction the DecisionTree
 @param level Level of the tree
 @param filters Filter for all nodes at a given level
 @param splits possible splits for all features
@@ -200,7 +200,7 @@ object DecisionTree extends Serializable with Logging {
 }
 }
 
-/*
+/**
 Find whether the sample is valid input for the current node.
 In other words, does it pass through all the filters for the current node.
 */
@@ -236,7 +236,9 @@ object DecisionTree extends Serializable with Logging {
 true
 }
 
-/*Finds the right bin for the given feature*/
+/**
+Finds the right bin for the given feature
+*/
 def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = {
 
 if (isFeatureContinuous){
@@ -266,7 +268,8 @@ object DecisionTree extends Serializable with Logging {
 
 }
 
-/*Finds bins for all nodes (and all features) at a given level
+/**
+Finds bins for all nodes (and all features) at a given level
 k features, l nodes (level = log2(l))
 Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk
 Denotes invalid sample for tree by noting bin for feature 1 as -1
@@ -343,7 +346,8 @@ object DecisionTree extends Serializable with Logging {
 }
 }
 
-/*Performs a sequential aggregation over a partition.
+/**
+Performs a sequential aggregation over a partition.
 
 for p bins, k features, l nodes (level = log2(l)) storage is of the form:
 b111_left_count,b111_right_count, .... , ..
@@ -370,7 +374,8 @@ object DecisionTree extends Serializable with Logging {
 }
 logDebug("binAggregateLength = " + binAggregateLength)
 
-/*Combines the aggregates from partitions
+/**
+Combines the aggregates from partitions
 @param agg1 Array containing aggregates from one or more partitions
 @param agg2 Array containing aggregates from one or more partitions
 
@@ -507,7 +512,7 @@ object DecisionTree extends Serializable with Logging {
 }
 }
 
-/*
+/**
 Extracts left and right split aggregates
 
 @param binData Array[Double] of size 2*numFeatures*numSplits
@@ -604,7 +609,7 @@ object DecisionTree extends Serializable with Logging {
 gains
 }
 
-/*
+/**
 Find the best split for a node given bin aggregate data
 
 @param binData Array[Double] of size 2*numSplits*numFeatures
@@ -669,7 +674,7 @@ object DecisionTree extends Serializable with Logging {
 bestSplits
 }
 
-/*
+/**
 Returns split and bins for decision tree calculation.
 
 @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
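
The doc comment for the bin-finding code above describes a flat per-sample layout: the label, then one bin index per (node, feature) pair, node-major, with -1 on the first feature marking a sample that is invalid for that node. A small sketch of that indexing (names and values here are illustrative, not taken from the commit):

object BinLayoutSketch {
  // Offset of the bin for (nodeIndex, featureIndex); slot 0 holds the label,
  // per the documented layout: label, b11..b1k, b21..b2k, ..., bl1..blk
  def binOffset(nodeIndex: Int, featureIndex: Int, numFeatures: Int): Int =
    1 + nodeIndex * numFeatures + featureIndex

  def main(args: Array[String]): Unit = {
    val k = 3 // k features
    // One sample, two nodes: label 1.0, bins (0, 2, 1) for node 0, invalid for node 1
    val row: Array[Double] = Array(1.0, 0, 2, 1, -1, -1, -1)
    println(row(binOffset(0, 1, k))) // 2.0: bin of the second feature at node 0
    println(row(binOffset(1, 0, k))) // -1.0: -1 on the first feature marks the sample invalid
  }
}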

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala

Lines changed: 3 additions & 0 deletions

@@ -16,6 +16,9 @@
 */
 package org.apache.spark.mllib.tree.configuration
 
+/**
+ * Enum to select the algorithm for the decision tree
+ */
 object Algo extends Enumeration {
 type Algo = Value
 val Classification, Regression = Value

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala

Lines changed: 3 additions & 0 deletions

@@ -16,6 +16,9 @@
 */
 package org.apache.spark.mllib.tree.configuration
 
+/**
+ * Enum to describe whether a feature is "continuous" or "categorical"
+ */
 object FeatureType extends Enumeration {
 type FeatureType = Value
 val Continuous, Categorical = Value

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala

Lines changed: 3 additions & 0 deletions

@@ -16,6 +16,9 @@
 */
 package org.apache.spark.mllib.tree.configuration
 
+/**
+ * Enum for selecting the quantile calculation strategy
+ */
 object QuantileStrategy extends Enumeration {
 type QuantileStrategy = Value
 val Sort, MinMax, ApproxHist = Value

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 13 additions & 0 deletions

@@ -20,6 +20,19 @@ import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 
+/**
+ * Stores all the configuration options for tree construction
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ *                                number of discrete values they take. For example, an entry (n ->
+ *                                k) implies the feature n is categorical with k categories 0,
+ *                                1, 2, ... , k-1. It's important to note that features are
+ *                                zero-indexed.
+ */
 class Strategy (
 val algo : Algo,
 val impurity : Impurity,
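
For reference, a hedged construction sketch based on the @param list above (the parameter order past algo and impurity is assumed from the scaladoc, since the constructor tail is cut off in this hunk; named arguments keep the sketch order-independent):

import org.apache.spark.mllib.tree.configuration.{Algo, QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

val strategy = new Strategy(
  algo = Algo.Classification,
  impurity = Gini,
  maxDepth = 5,
  maxBins = 100,
  quantileCalculationStrategy = QuantileStrategy.Sort,
  // feature 0 is categorical with 3 categories: 0, 1, 2 (features are zero-indexed)
  categoricalFeaturesInfo = Map(0 -> 3))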

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala

Lines changed: 10 additions & 0 deletions

@@ -18,10 +18,20 @@ package org.apache.spark.mllib.tree.impurity
 
 import javax.naming.OperationNotSupportedException
 
+/**
+ * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
+ * binary classification.
+ */
 object Entropy extends Impurity {
 
 def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
 
+/**
+ * entropy calculation
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return entropy value
+ */
 def calculate(c0: Double, c1: Double): Double = {
 if (c0 == 0 || c1 == 0) {
 0
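
Only the zero-count branch is visible in this hunk. For binary classification, the linked entropy function is H = -p0*log2(p0) - p1*log2(p1) over the class proportions; a hedged sketch of the general branch (the commit's own body is not shown in this diff):

object EntropySketch {
  def log2(x: Double) = scala.math.log(x) / scala.math.log(2)

  // Assumed general case, matching the signature shown above
  def calculate(c0: Double, c1: Double): Double =
    if (c0 == 0 || c1 == 0) 0
    else {
      val total = c0 + c1
      val (p0, p1) = (c0 / total, c1 / total)
      -p0 * log2(p0) - p1 * log2(p1)
    }
}
// EntropySketch.calculate(50, 50) == 1.0, the maximum for a balanced binary node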

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala

Lines changed: 10 additions & 0 deletions

@@ -18,8 +18,18 @@ package org.apache.spark.mllib.tree.impurity
 
 import javax.naming.OperationNotSupportedException
 
+/**
+ * Class for calculating the [[http://en.wikipedia.org/wiki/Gini_coefficient Gini
+ * coefficent]] during binary classification
+ */
 object Gini extends Impurity {
 
+/**
+ * gini coefficient calculation
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return gini coefficient value
+ */
 def calculate(c0 : Double, c1 : Double): Double = {
 if (c0 == 0 || c1 == 0) {
 0
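
As with Entropy, only the degenerate branch appears in this hunk. The standard two-class Gini impurity is 1 - p0^2 - p1^2; a hedged sketch of the general case (assumed, not shown in the diff):

object GiniSketch {
  def calculate(c0: Double, c1: Double): Double =
    if (c0 == 0 || c1 == 0) 0
    else {
      val total = c0 + c1
      val (p0, p1) = (c0 / total, c1 / total)
      1 - p0 * p0 - p1 * p1
    }
}
// GiniSketch.calculate(50, 50) == 0.5, the maximum for two balanced classes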

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala

Lines changed: 10 additions & 0 deletions

@@ -19,9 +19,19 @@ package org.apache.spark.mllib.tree.impurity
 import javax.naming.OperationNotSupportedException
 import org.apache.spark.Logging
 
+/**
+ * Class for calculating variance during regression
+ */
 object Variance extends Impurity with Logging {
 def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")
 
+/**
+ * variance calculation
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ * @return
+ */
 def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
 val squaredLoss = sumSquares - (sum*sum)/count
 squaredLoss/count
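
The body shown computes the population variance from sufficient statistics: Var = (sumSquares - sum^2/count) / count. A quick worked check of the formula:

import org.apache.spark.mllib.tree.impurity.Variance

// Labels 1.0, 2.0, 3.0 give count = 3, sum = 6, sumSquares = 14:
// squaredLoss = 14 - 36/3 = 2, so variance = 2/3
println(Variance.calculate(3.0, 6.0, 14.0)) // 0.666..., the population variance of {1, 2, 3}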

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala

Lines changed: 11 additions & 0 deletions

@@ -18,6 +18,17 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
+/**
+ * Used for "binning" the features bins for faster best split calculation. For a continuous
+ * feature, a bin is determined by a low and a high "split". For a categorical feature,
+ * the a bin is determined using a single label value (category).
+ * @param lowSplit signifying the lower threshold for the continuous feature to be
+ *                 accepted in the bin
+ * @param highSplit signifying the upper threshold for the continuous feature to be
+ *                  accepted in the bin
+ * @param featureType type of feature -- categorical or continuous
+ * @param category categorical label value accepted in the bin
+ */
 case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) {
 
 }

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 24 additions & 2 deletions

@@ -16,12 +16,23 @@
 */
 package org.apache.spark.mllib.tree.model
 
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.rdd.RDD
 
+/**
+ * Model to store the decision tree parameters
+ * @param topNode root node
+ * @param algo algorithm type -- classification or regression
+ */
 class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {
 
-def predict(features : Array[Double]) = {
+/**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return Double prediction from the trained model
+ */
+def predict(features : Array[Double]) : Double = {
 algo match {
 case Classification => {
 if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
@@ -32,4 +43,15 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl
 }
 }
 
+/**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Int] where each entry contains the corresponding prediction
+ */
+def predict(features: RDD[Array[Double]]): RDD[Double] = {
+features.map(x => predict(x))
+}
+
+
 }
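
Usage of the two overloads above, as a sketch (a trained `model` and a SparkContext `sc` are assumed to exist; note the scaladoc's RDD[Int] looks like a slip, since the signature returns RDD[Double]):

import org.apache.spark.rdd.RDD

val single: Double = model.predict(Array(0.0, 1.5, 2.0))
val batch: RDD[Double] =
  model.predict(sc.parallelize(Seq(Array(0.0, 1.5, 2.0), Array(1.0, 0.5, 3.0))))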

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala

Lines changed: 5 additions & 0 deletions

@@ -16,6 +16,11 @@
 */
 package org.apache.spark.mllib.tree.model
 
+/**
+ * Filter specifying a split and type of comparison to be applied on features
+ * @param split split specifying the feature index, type and threshold
+ * @param comparison integer specifying <,=,>
+ */
 case class Filter(split : Split, comparison : Int) {
 // Comparison -1,0,1 signifies <.=,>
 override def toString = " split = " + split + "comparison = " + comparison
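
A minimal sketch (not from this commit) of the comparison convention the inline comment names, where -1, 0, 1 stand for <, =, > against the split's threshold:

def passesFilter(featureValue: Double, threshold: Double, comparison: Int): Boolean =
  comparison match {
    case -1 => featureValue < threshold  // keep samples below the threshold
    case 0  => featureValue == threshold // keep samples equal to the threshold
    case 1  => featureValue > threshold  // keep samples above the threshold
  }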

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 8 additions & 0 deletions

@@ -16,6 +16,14 @@
 */
 package org.apache.spark.mllib.tree.model
 
+/**
+ * Information gain statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param leftImpurity left node impurity
+ * @param rightImpurity right node impurity
+ * @param predict predicted value
+ */
 class InformationGainStats(
 val gain : Double,
 val impurity: Double,
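
How the fields above typically relate, as a hedged sketch: the gain of a split is conventionally the parent impurity minus the count-weighted child impurities. The weighting is an assumption; the computation itself is not in this commit's diff.

def infoGain(impurity: Double,
             leftImpurity: Double, leftCount: Double,
             rightImpurity: Double, rightCount: Double): Double = {
  val total = leftCount + rightCount
  // parent impurity minus weighted average of child impurities
  impurity - (leftCount / total) * leftImpurity - (rightCount / total) * rightImpurity
}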

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 10 additions & 0 deletions

@@ -20,6 +20,16 @@ import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
+/**
+ * Node in a decision tree
+ * @param id integer node id
+ * @param predict predicted value at the node
+ * @param isLeaf whether the leaf is a node
+ * @param split split to calculate left and right nodes
+ * @param leftNode left child
+ * @param rightNode right child
+ * @param stats information gain stats
+ */
 class Node ( val id : Int,
 val predict : Double,
 val isLeaf : Boolean,
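
`predictIfLeaf`, used by DecisionTreeModel above, is not shown in this diff; a plausible traversal for continuous splits, sketched here with hypothetical stand-in types rather than the commit's own classes:

case class SplitSketch(feature: Int, threshold: Double)
case class NodeSketch(predict: Double, isLeaf: Boolean, split: Option[SplitSketch],
                      leftNode: Option[NodeSketch], rightNode: Option[NodeSketch])

// Walk down until a leaf is reached, then return its predicted value
def predictIfLeafSketch(node: NodeSketch, features: Array[Double]): Double =
  if (node.isLeaf) node.predict
  else {
    val s = node.split.get
    if (features(s.feature) <= s.threshold) predictIfLeafSketch(node.leftNode.get, features)
    else predictIfLeafSketch(node.rightNode.get, features)
  }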

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 29 additions & 6 deletions

@@ -18,6 +18,13 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
 
+/**
+ * Split applied to a feature
+ * @param feature feature index
+ * @param threshold threshold for continuous feature
+ * @param featureType type of feature -- categorical or continuous
+ * @param categories accepted values for categorical variables
+ */
 case class Split(
 feature: Int,
 threshold : Double,
@@ -29,12 +36,28 @@ case class Split(
 ", categories = " + categories
 }
 
-class DummyLowSplit(feature: Int, kind : FeatureType)
-extends Split(feature, Double.MinValue, kind, List())
+/**
+ * Split with minimum threshold for continuous features. Helps with the smallest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyLowSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MinValue, featureType, List())
 
-class DummyHighSplit(feature: Int, kind : FeatureType)
-extends Split(feature, Double.MaxValue, kind, List())
+/**
+ * Split with maximum threshold for continuous features. Helps with the highest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyHighSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MaxValue, featureType, List())
 
-class DummyCategoricalSplit(feature: Int, kind : FeatureType)
-extends Split(feature, Double.MaxValue, kind, List())
+/**
+ * Split with no acceptable feature values for categorical features. Helps with the first bin
+ * creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyCategoricalSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MaxValue, featureType, List())
 
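Tying Split back to Bin (defined earlier in this commit): for a continuous feature a bin sits between a low and a high split, with the dummy splits bounding the outermost bins; for a categorical feature the bin carries a single category value. An illustrative sketch with made-up values:

import org.apache.spark.mllib.tree.configuration.FeatureType

val low  = new DummyLowSplit(0, FeatureType.Continuous)  // threshold Double.MinValue
val high = Split(0, 3.5, FeatureType.Continuous, List()) // explicit upper threshold
// category is presumably unused for continuous bins (assumption)
val continuousBin = Bin(low, high, FeatureType.Continuous, category = Double.MinValue)

val catSplit = new DummyCategoricalSplit(1, FeatureType.Categorical)
val categoricalBin = Bin(catSplit, catSplit, FeatureType.Categorical, category = 2.0)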
