Skip to content

Commit 02c595c

Browse files
committed
added command line parsing
Signed-off-by: Manish Amde <manish9ue@gmail.com>
1 parent 98ec8d5 commit 02c595c

File tree

3 files changed

+52
-29
lines changed

3 files changed

+52
-29
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala

Lines changed: 0 additions & 21 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ object DecisionTree extends Serializable with Logging {
285285

286286
/*Combines the aggregates from partitions
287287
@param agg1 Array containing aggregates from one or more partitions
288-
@param agg2 Array contianing aggregates from one or more partitions
288+
@param agg2 Array containing aggregates from one or more partitions
289289
290290
@return Combined aggregate from agg1 and agg2
291291
*/

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,69 @@
1717
package org.apache.spark.mllib.tree
1818

1919
import org.apache.spark.{Logging, SparkContext}
20-
import org.apache.spark.mllib.tree.impurity.Gini
20+
import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance}
2121
import org.apache.spark.rdd.RDD
2222
import org.apache.spark.mllib.regression.LabeledPoint
2323
import org.apache.spark.mllib.tree.model.DecisionTreeModel
2424

2525
object DecisionTreeRunner extends Logging {
2626

27+
val usage = """
28+
Usage: DecisionTreeRunner <master>[slices] --kind <Classification,Regression> --trainDataDir path --testDataDir path [--maxDepth num] [--impurity <Gini,Entropy,Variance>] [--maxBins num]
29+
"""
30+
2731

2832
def main(args: Array[String]) {
2933

34+
if (args.length < 2) {
35+
System.err.println(usage)
36+
System.exit(1)
37+
}
38+
3039
val sc = new SparkContext(args(0), "DecisionTree")
31-
val data = loadLabeledData(sc, args(1))
32-
val maxDepth = args(2).toInt
33-
val maxBins = args(3).toInt
3440

35-
val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, maxBins = maxBins)
36-
val model = new DecisionTree(strategy).train(data)
3741

38-
val accuracy = accuracyScore(model, data)
42+
val arglist = args.toList.drop(1)
43+
type OptionMap = Map[Symbol, Any]
44+
45+
def nextOption(map : OptionMap, list: List[String]) : OptionMap = {
46+
def isSwitch(s : String) = (s(0) == '-')
47+
list match {
48+
case Nil => map
49+
case "--kind" :: string :: tail => nextOption(map ++ Map('kind -> string), tail)
50+
case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
51+
case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
52+
case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail)
53+
case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail)
54+
case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail)
55+
case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
56+
case option :: tail => println("Unknown option "+option)
57+
exit(1)
58+
}
59+
}
60+
val options = nextOption(Map(),arglist)
61+
logDebug(options.toString())
62+
63+
val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
64+
65+
val typeStr = options.get('type).toString
66+
//TODO: Create enum
67+
val impurityStr = options.getOrElse('impurity,if (typeStr == "classification") "Gini" else "Variance").toString
68+
val impurity = {
69+
impurityStr match {
70+
case "Gini" => Gini
71+
case "Entropy" => Entropy
72+
case "Variance" => Variance
73+
}
74+
}
75+
val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt
76+
val maxBins = options.getOrElse('maxBins,"100").toString.toInt
77+
78+
val strategy = new Strategy(kind = typeStr, impurity = Gini, maxDepth = maxDepth, maxBins = maxBins)
79+
val model = new DecisionTree(strategy).train(trainData)
80+
81+
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
82+
val accuracy = accuracyScore(model, testData)
3983
logDebug("accuracy = " + accuracy)
4084

4185
sc.stop()

0 commit comments

Comments
 (0)