|
17 | 17 | package org.apache.spark.mllib.tree
|
18 | 18 |
|
19 | 19 | import org.apache.spark.{Logging, SparkContext}
|
20 |
| -import org.apache.spark.mllib.tree.impurity.Gini |
| 20 | +import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance} |
21 | 21 | import org.apache.spark.rdd.RDD
|
22 | 22 | import org.apache.spark.mllib.regression.LabeledPoint
|
23 | 23 | import org.apache.spark.mllib.tree.model.DecisionTreeModel
|
24 | 24 |
|
25 | 25 | object DecisionTreeRunner extends Logging {
|
26 | 26 |
|
| 27 | + val usage = """ |
| 28 | + Usage: DecisionTreeRunner <master>[slices] --kind <Classification,Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <Gini,Entropy,Variance>] [--maxBins num] |
| 29 | + """ |
| 30 | + |
27 | 31 |
|
28 | 32 | def main(args: Array[String]) {
|
29 | 33 |
|
| 34 | + if (args.length < 2) { |
| 35 | + System.err.println(usage) |
| 36 | + System.exit(1) |
| 37 | + } |
| 38 | + |
30 | 39 | val sc = new SparkContext(args(0), "DecisionTree")
|
31 |
| - val data = loadLabeledData(sc, args(1)) |
32 |
| - val maxDepth = args(2).toInt |
33 |
| - val maxBins = args(3).toInt |
34 | 40 |
|
35 |
| - val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, maxBins = maxBins) |
36 |
| - val model = new DecisionTree(strategy).train(data) |
37 | 41 |
|
38 |
| - val accuracy = accuracyScore(model, data) |
| 42 | + val arglist = args.toList.drop(1) |
| 43 | + type OptionMap = Map[Symbol, Any] |
| 44 | + |
/**
 * Recursively folds the remaining command-line tokens into an option map.
 *
 * Each recognised `--switch value` pair is stored under the switch's symbol
 * (`--kind` -> 'kind, `--impurity` -> 'impurity, `--maxDepth` -> 'maxDepth,
 * `--maxBins` -> 'maxBins, `--trainDataDir` -> 'trainDataDir,
 * `--testDataDir` -> 'testDataDir). A single trailing token that is not a
 * switch is stored under 'infile. Values are kept as the raw strings.
 *
 * Any other token (an unrecognised option, or a recognised switch that is the
 * last token and therefore has no value) prints a message to stderr and
 * terminates the JVM with exit status 1.
 *
 * @param map  options accumulated so far
 * @param list command-line tokens still to be consumed
 * @return the accumulated options once `list` is exhausted
 */
def nextOption(map : OptionMap, list: List[String]) : OptionMap = {
  // startsWith is safe on the empty string, unlike indexing s(0).
  def isSwitch(s : String) = s.startsWith("-")

  // Switches that consume the token following them, and the key each is stored under.
  val valueSwitches = Map(
    "--kind" -> 'kind,
    "--impurity" -> 'impurity,
    "--maxDepth" -> 'maxDepth,
    "--maxBins" -> 'maxBins,
    "--trainDataDir" -> 'trainDataDir,
    "--testDataDir" -> 'testDataDir
  )

  list match {
    case Nil => map
    case switch :: value :: tail if valueSwitches.contains(switch) =>
      nextOption(map + (valueSwitches(switch) -> value), tail)
    // A known switch as the very last token has nothing left to consume;
    // report it instead of silently recording it as the input file.
    case switch :: Nil if valueSwitches.contains(switch) =>
      System.err.println("Missing value for option " + switch)
      sys.exit(1)
    case string :: Nil if !isSwitch(string) =>
      nextOption(map + ('infile -> string), Nil)
    case option :: _ =>
      System.err.println("Unknown option " + option)
      sys.exit(1)
  }
}
| 60 | + val options = nextOption(Map(),arglist) |
| 61 | + logDebug(options.toString()) |
| 62 | + |
| 63 | + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) |
| 64 | + |
| 65 | + val typeStr = options.get('type).toString |
| 66 | + //TODO: Create enum |
| 67 | + val impurityStr = options.getOrElse('impurity,if (typeStr == "classification") "Gini" else "Variance").toString |
| 68 | + val impurity = { |
| 69 | + impurityStr match { |
| 70 | + case "Gini" => Gini |
| 71 | + case "Entropy" => Entropy |
| 72 | + case "Variance" => Variance |
| 73 | + } |
| 74 | + } |
| 75 | + val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt |
| 76 | + val maxBins = options.getOrElse('maxBins,"100").toString.toInt |
| 77 | + |
| 78 | + val strategy = new Strategy(kind = typeStr, impurity = Gini, maxDepth = maxDepth, maxBins = maxBins) |
| 79 | + val model = new DecisionTree(strategy).train(trainData) |
| 80 | + |
| 81 | + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) |
| 82 | + val accuracy = accuracyScore(model, testData) |
39 | 83 | logDebug("accuracy = " + accuracy)
|
40 | 84 |
|
41 | 85 | sc.stop()
|
|
0 commit comments