Skip to content

Commit 485eaae

Browse files
committed
implicit conversion from LabeledPoint to WeightedLabeledPoint
1 parent 3d7f911 commit 485eaae

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.point
19+
20+
import org.apache.spark.rdd.RDD
21+
import org.apache.spark.mllib.regression.LabeledPoint
22+
23+
object PointConverter {
24+
25+
implicit def LabeledPoint2WeightedLabeledPoint(
26+
points : RDD[LabeledPoint]): RDD[WeightedLabeledPoint] = {
27+
points.map(point => new WeightedLabeledPoint(point.label,point.features))
28+
}
29+
30+
}

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20+
import org.apache.spark.mllib.point.PointConverter._
2021
import org.apache.spark.annotation.Experimental
2122
import org.apache.spark.Logging
2223
import org.apache.spark.mllib.regression.LabeledPoint
@@ -211,9 +212,7 @@ object DecisionTree extends Serializable with Logging {
211212
* @return a DecisionTreeModel that can be used for prediction
212213
*/
213214
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
214-
// Converting from standard instance format to weighted input format for tree training
215-
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
216-
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
215+
new DecisionTree(strategy).train(input)
217216
}
218217

219218
/**
@@ -235,9 +234,7 @@ object DecisionTree extends Serializable with Logging {
235234
impurity: Impurity,
236235
maxDepth: Int): DecisionTreeModel = {
237236
val strategy = new Strategy(algo, impurity, maxDepth)
238-
// Converting from standard instance format to weighted input format for tree training
239-
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
240-
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
237+
new DecisionTree(strategy).train(input)
241238
}
242239

243240
/**
@@ -261,9 +258,7 @@ object DecisionTree extends Serializable with Logging {
261258
maxDepth: Int,
262259
numClassesForClassification: Int): DecisionTreeModel = {
263260
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
264-
// Converting from standard instance format to weighted input format for tree training
265-
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
266-
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
261+
new DecisionTree(strategy).train(input)
267262
}
268263

269264

@@ -294,9 +289,7 @@ object DecisionTree extends Serializable with Logging {
294289
labelWeights: Map[Int,Int]): DecisionTreeModel = {
295290
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification,
296291
labelWeights = labelWeights)
297-
// Converting from standard instance format to weighted input format for tree training
298-
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
299-
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
292+
new DecisionTree(strategy).train(input)
300293
}
301294

302295
/**
@@ -337,9 +330,7 @@ object DecisionTree extends Serializable with Logging {
337330
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
338331
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
339332
quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights)
340-
// Converting from standard instance format to weighted input format for tree training
341-
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
342-
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
333+
new DecisionTree(strategy).train(input)
343334
}
344335

345336
private val InvalidBinIndex = -1

0 commit comments

Comments
 (0)