1
+ /*
2
+ * Licensed to the Apache Software Foundation (ASF) under one or more
3
+ * contributor license agreements. See the NOTICE file distributed with
4
+ * this work for additional information regarding copyright ownership.
5
+ * The ASF licenses this file to You under the Apache License, Version 2.0
6
+ * (the "License"); you may not use this file except in compliance with
7
+ * the License. You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ package org .apache .spark .mllib .tree
19
+
20
+ import org .apache .spark .rdd .RDD
21
+ import org .apache .spark .mllib .regression .LabeledPoint
22
+ import org .apache .spark .mllib .tree .model .{Split , Bin , DecisionTreeModel }
23
+
24
+
25
/**
 * Trainer for a decision tree, configured by a [[Strategy]].
 *
 * @param strategy configuration handed to the split/bin computation
 */
class DecisionTree(val strategy: Strategy) {

  /**
   * Runs the (currently skeletal) training pipeline over `input`.
   *
   * The input RDD is cached, the split and bin tables are computed through
   * `DecisionTree.find_splits_bins`, and a newly constructed model is returned.
   * The computed splits and bins are not yet used to build the model.
   *
   * @param input labeled training points
   * @return a freshly constructed [[DecisionTreeModel]]
   */
  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
    // Several passes are made over the data, so keep it cached.
    input.cache()

    // TODO: Find all splits and bins using quantiles including support for categorical features, single-pass
    val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)

    // TODO: Level-wise training of tree and obtain Decision Tree model

    // The last expression is the method's result; no explicit `return` is needed.
    new DecisionTreeModel()
  }
}
42
+
43
object DecisionTree {

  /**
   * Allocates the per-feature split and bin tables for `input`.
   *
   * The feature count is read from the first element of `input`; both returned
   * tables have dimensions `numFeatures x strategy.numSplits`. Their entries are
   * left at the array default, since this method does not populate them yet.
   *
   * @param input    labeled points; must contain at least one element
   * @param strategy supplies `numSplits`
   * @return the pair `(splits, bins)` of freshly allocated two-dimensional arrays
   * @throws IllegalArgumentException if `input` is empty
   */
  def find_splits_bins(
      input: RDD[LabeledPoint],
      strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
    val numSplits = strategy.numSplits

    // TODO: Justify this calculation
    // Widen to Long *before* multiplying: the product of two Ints is computed in
    // Int arithmetic and can silently wrap around before it is assigned to a Long.
    // NOTE(review): assumes strategy.numSplits is an Int, as Array.ofDim below implies.
    val requiredSamples: Long = numSplits.toLong * numSplits

    val count: Long = input.count()
    // Fail with a clear message instead of an ArrayIndexOutOfBoundsException
    // from `take(1)(0)` further down.
    require(count > 0, "Cannot find splits and bins for an empty input RDD")

    // Sample size capped by the data size; not consumed yet (sampling is still a TODO).
    val numSamples: Long = math.min(requiredSamples, count)

    val numFeatures = input.take(1)(0).features.length
    (Array.ofDim[Split](numFeatures, numSplits), Array.ofDim[Bin](numFeatures, numSplits))
  }
}
0 commit comments