Skip to content

Commit eaf84c0

Browse files
committed
Added DecisionTree static train() methods API to match Python, but without default parameters
1 parent 41e0a21 commit eaf84c0

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,30 @@ object DecisionTree extends Serializable with Logging {
300300
new DecisionTree(strategy).train(input)
301301
}
302302

303+
def train(
304+
input: RDD[LabeledPoint],
305+
algo: Algo,
306+
numClassesForClassification: Int,
307+
categoricalFeaturesInfo: Map[Int,Int],
308+
impurity: Impurity,
309+
maxDepth: Int,
310+
maxBins: Int): DecisionTreeModel = ???
311+
312+
def trainClassifier(
313+
input: RDD[LabeledPoint],
314+
numClassesForClassification: Int,
315+
categoricalFeaturesInfo: Map[Int,Int],
316+
impurity: Impurity,
317+
maxDepth: Int,
318+
maxBins: Int): DecisionTreeModel = ???
319+
320+
def trainRegressor(
321+
input: RDD[LabeledPoint],
322+
categoricalFeaturesInfo: Map[Int,Int],
323+
impurity: Impurity,
324+
maxDepth: Int,
325+
maxBins: Int): DecisionTreeModel = ???
326+
303327
private val InvalidBinIndex = -1
304328

305329
/**

0 commit comments

Comments
 (0)