@@ -172,7 +172,11 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
172
172
object DecisionTree extends Serializable with Logging {
173
173
174
174
/**
175
- * Method to train a decision tree model over an RDD
175
+ * Method to train a decision tree model where the instances are represented as an RDD of
176
+ * (label, features) pairs. The method supports binary classification and regression. For
177
+ * binary classification, the label for each instance should either be 0 or 1 to denote the two
178
+ * classes. The parameters for the algorithm are specified using the strategy parameter.
179
+ *
176
180
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
177
181
* for DecisionTree
178
182
* @param strategy The configuration parameters for the tree algorithm which specify the type
@@ -185,7 +189,11 @@ object DecisionTree extends Serializable with Logging {
185
189
}
186
190
187
191
/**
188
- * Method to train a decision tree model over an RDD
192
+ * Method to train a decision tree model where the instances are represented as an RDD of
193
+ * (label, features) pairs. The method supports binary classification and regression. For
194
+ * binary classification, the label for each instance should either be 0 or 1 to denote the two
195
+ * classes.
196
+ *
189
197
* @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
190
198
* training data
191
199
* @param algo algorithm, classification or regression
@@ -204,8 +212,13 @@ object DecisionTree extends Serializable with Logging {
204
212
205
213
206
214
/**
207
- * Method to train a decision tree model over an RDD
208
- * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as
215
+ * Method to train a decision tree model where the instances are represented as an RDD of
216
+ * (label, features) pairs. The decision tree method supports binary classification and
217
+ * regression. For binary classification, the label for each instance should be either 0 or
218
+ * 1 to denote the two classes. The method also supports categorical feature inputs where the
219
+ * number of categories can be specified using the categoricalFeaturesInfo option.
220
+ *
221
+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
209
222
* training data for DecisionTree
210
223
* @param algo classification or regression
211
224
* @param impurity criterion used for information gain calculation
@@ -236,6 +249,7 @@ object DecisionTree extends Serializable with Logging {
236
249
237
250
/**
238
251
* Returns an array of optimal splits for all nodes at a given level
252
+ *
239
253
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
240
254
* for DecisionTree
241
255
* @param parentImpurities Impurities for all parent nodes for the current level
@@ -247,7 +261,7 @@ object DecisionTree extends Serializable with Logging {
247
261
* @param bins possible bins for all features
248
262
* @return array of splits with best splits for all nodes at a given level.
249
263
*/
250
- private def findBestSplits (
264
+ protected [tree] def findBestSplits (
251
265
input : RDD [LabeledPoint ],
252
266
parentImpurities : Array [Double ],
253
267
strategy : Strategy ,
@@ -885,7 +899,7 @@ object DecisionTree extends Serializable with Logging {
885
899
* .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
886
900
* .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
887
901
*/
888
- private def findSplitsBins (
902
+ protected [tree] def findSplitsBins (
889
903
input : RDD [LabeledPoint ],
890
904
strategy : Strategy ): (Array [Array [Split ]], Array [Array [Bin ]]) = {
891
905
val count = input.count()
0 commit comments