Skip to content

Commit 62dc723

Browse files
committed
updating javadoc and converting helper methods to package private to allow unit testing
1 parent 201702f commit 62dc723

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,11 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
172172
object DecisionTree extends Serializable with Logging {
173173

174174
/**
175-
* Method to train a decision tree model over an RDD
175+
* Method to train a decision tree model where the instances are represented as an RDD of
176+
* (label, features) pairs. The method supports binary classification and regression. For the
177+
* binary classification, the label for each instance should either be 0 or 1 to denote the two
178+
* classes. The parameters for the algorithm are specified using the strategy parameter.
179+
*
176180
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
177181
* for DecisionTree
178182
* @param strategy The configuration parameters for the tree algorithm which specify the type
@@ -185,7 +189,11 @@ object DecisionTree extends Serializable with Logging {
185189
}
186190

187191
/**
188-
* Method to train a decision tree model over an RDD
192+
* Method to train a decision tree model where the instances are represented as an RDD of
193+
* (label, features) pairs. The method supports binary classification and regression. For the
194+
* binary classification, the label for each instance should either be 0 or 1 to denote the two
195+
* classes.
196+
*
189197
* @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
190198
* training data
191199
* @param algo algorithm, classification or regression
@@ -204,8 +212,13 @@ object DecisionTree extends Serializable with Logging {
204212

205213

206214
/**
207-
* Method to train a decision tree model over an RDD
208-
* @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
215+
* Method to train a decision tree model where the instances are represented as an RDD of
216+
* (label, features) pairs. The decision tree method supports binary classification and
217+
* regression. For the binary classification, the label for each instance should either be 0 or
218+
* 1 to denote the two classes. The method also supports categorical feature inputs where the
219+
* number of categories can be specified using the categoricalFeaturesInfo option.
220+
*
221+
* @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
209222
* training data for DecisionTree
210223
* @param algo classification or regression
211224
* @param impurity criterion used for information gain calculation
@@ -236,6 +249,7 @@ object DecisionTree extends Serializable with Logging {
236249

237250
/**
238251
* Returns an array of optimal splits for all nodes at a given level
252+
*
239253
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
240254
* for DecisionTree
241255
* @param parentImpurities Impurities for all parent nodes for the current level
@@ -247,7 +261,7 @@ object DecisionTree extends Serializable with Logging {
247261
* @param bins possible bins for all features
248262
* @return array of splits with best splits for all nodes at a given level.
249263
*/
250-
private def findBestSplits(
264+
protected[tree] def findBestSplits(
251265
input: RDD[LabeledPoint],
252266
parentImpurities: Array[Double],
253267
strategy: Strategy,
@@ -885,7 +899,7 @@ object DecisionTree extends Serializable with Logging {
885899
* .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
886900
* .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
887901
*/
888-
private def findSplitsBins(
902+
protected[tree] def findSplitsBins(
889903
input: RDD[LabeledPoint],
890904
strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
891905
val count = input.count()

0 commit comments

Comments
 (0)