@@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.util.Utils.memoryStringToMb
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 /**
@@ -79,7 +80,7 @@ class DecisionTree(private val strategy: Strategy) extends Serializable with Lo
     // Calculate level for single group construction
 
     // Max memory usage for aggregates
-    val maxMemoryUsage = scala.math.pow(2, 27).toInt // 128MB
+    val maxMemoryUsage = strategy.maxMemory * 1024 * 1024
     logDebug("max memory usage for aggregates = " + maxMemoryUsage)
     val numElementsPerNode = {
       strategy.algo match {
@@ -1158,10 +1159,13 @@ object DecisionTree extends Serializable with Logging {
 
     val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
     val maxBins = options.getOrElse('maxBins, "100").toString.toInt
+    val maxMemUsage = memoryStringToMb(options.getOrElse('maxMemory, "128m").toString)
 
-    val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
+    val strategy = new Strategy(algo, impurity, maxDepth, maxBins, maxMemory = maxMemUsage)
     val model = DecisionTree.train(trainData, strategy)
 
+
+
     // Load test data.
     val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
 
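For context, here is a minimal, self-contained sketch of what this change does end to end: parse a human-readable memory string into megabytes, then scale it to a byte budget for the aggregate computation. `MaxMemorySketch` and `parseMemoryString` are hypothetical names; the parsing logic only approximates Spark's `Utils.memoryStringToMb` (the real utility also handles a "t" suffix) and is inlined here so the snippet runs without Spark on the classpath.

object MaxMemorySketch {
  // Rough stand-in for org.apache.spark.util.Utils.memoryStringToMb:
  // parse strings such as "512k", "128m", "2g" into megabytes.
  def parseMemoryString(str: String): Int = {
    val lower = str.toLowerCase
    if (lower.endsWith("k")) (lower.dropRight(1).toLong / 1024).toInt
    else if (lower.endsWith("m")) lower.dropRight(1).toInt
    else if (lower.endsWith("g")) lower.dropRight(1).toInt * 1024
    else (lower.toLong / (1024 * 1024)).toInt // bare number taken as bytes
  }

  def main(args: Array[String]): Unit = {
    // "128m" is the default the patch gives the 'maxMemory option, so
    // omitting the option keeps the old behavior.
    val maxMemoryMb = parseMemoryString("128m")
    val maxMemoryUsage = maxMemoryMb * 1024 * 1024 // bytes, as in the patch
    println("max memory usage for aggregates = " + maxMemoryUsage)
    // prints 134217728, exactly the former scala.math.pow(2, 27).toInt
  }
}

Note that the new default preserves the prior behavior: 128 MB expressed in bytes is 2^27, the value that was previously hard-coded.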