1717
1818package org .apache .spark .mllib .tree
1919
20+ import org .apache .spark .mllib .point .PointConverter ._
2021import org .apache .spark .annotation .Experimental
2122import org .apache .spark .Logging
2223import org .apache .spark .mllib .regression .LabeledPoint
@@ -211,9 +212,7 @@ object DecisionTree extends Serializable with Logging {
211212 * @return a DecisionTreeModel that can be used for prediction
212213 */
213214 def train (input : RDD [LabeledPoint ], strategy : Strategy ): DecisionTreeModel = {
214- // Converting from standard instance format to weighted input format for tree training
215- val weightedInput = input.map(x => WeightedLabeledPoint (x.label, x.features))
216- new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
215+ new DecisionTree (strategy).train(input)
217216 }
218217
219218 /**
@@ -235,9 +234,7 @@ object DecisionTree extends Serializable with Logging {
235234 impurity : Impurity ,
236235 maxDepth : Int ): DecisionTreeModel = {
237236 val strategy = new Strategy (algo, impurity, maxDepth)
238- // Converting from standard instance format to weighted input format for tree training
239- val weightedInput = input.map(x => WeightedLabeledPoint (x.label, x.features))
240- new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
237+ new DecisionTree (strategy).train(input)
241238 }
242239
243240 /**
@@ -261,9 +258,7 @@ object DecisionTree extends Serializable with Logging {
261258 maxDepth : Int ,
262259 numClassesForClassification : Int ): DecisionTreeModel = {
263260 val strategy = new Strategy (algo, impurity, maxDepth, numClassesForClassification)
264- // Converting from standard instance format to weighted input format for tree training
265- val weightedInput = input.map(x => WeightedLabeledPoint (x.label, x.features))
266- new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
261+ new DecisionTree (strategy).train(input)
267262 }
268263
269264
@@ -294,9 +289,7 @@ object DecisionTree extends Serializable with Logging {
294289 labelWeights : Map [Int ,Int ]): DecisionTreeModel = {
295290 val strategy = new Strategy (algo, impurity, maxDepth, numClassesForClassification,
296291 labelWeights = labelWeights)
297- // Converting from standard instance format to weighted input format for tree training
298- val weightedInput = input.map(x => WeightedLabeledPoint (x.label, x.features))
299- new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
292+ new DecisionTree (strategy).train(input)
300293 }
301294
302295 /**
@@ -337,9 +330,7 @@ object DecisionTree extends Serializable with Logging {
337330 categoricalFeaturesInfo : Map [Int ,Int ]): DecisionTreeModel = {
338331 val strategy = new Strategy (algo, impurity, maxDepth, numClassesForClassification, maxBins,
339332 quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights)
340- // Converting from standard instance format to weighted input format for tree training
341- val weightedInput = input.map(x => WeightedLabeledPoint (x.label, x.features))
342- new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
333+ new DecisionTree (strategy).train(input)
343334 }
344335
345336 private val InvalidBinIndex = - 1
0 commit comments