Skip to content

Commit abc5a23

Browse files
committed
Parameterizing max memory.
1 parent 50b143a commit abc5a23

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
3131
import org.apache.spark.mllib.tree.model._
3232
import org.apache.spark.rdd.RDD
3333
import org.apache.spark.util.random.XORShiftRandom
34+
import org.apache.spark.util.Utils.memoryStringToMb
3435
import org.apache.spark.mllib.linalg.{Vector, Vectors}
3536

3637
/**
@@ -79,7 +80,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7980
// Calculate level for single group construction
8081

8182
// Max memory usage for aggregates
82-
val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB
83+
val maxMemoryUsage = strategy.maxMemory * 1024 * 1024
8384
logDebug("max memory usage for aggregates = " + maxMemoryUsage)
8485
val numElementsPerNode = {
8586
strategy.algo match {
@@ -1158,10 +1159,13 @@ object DecisionTree extends Serializable with Logging {
11581159

11591160
val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
11601161
val maxBins = options.getOrElse('maxBins, "100").toString.toInt
1162+
val maxMemUsage = memoryStringToMb(options.getOrElse('maxMemory, "128m").toString)
11611163

1162-
val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
1164+
val strategy = new Strategy(algo, impurity, maxDepth, maxBins, maxMemory=maxMemUsage)
11631165
val model = DecisionTree.train(trainData, strategy)
11641166

1167+
1168+
11651169
// Load test data.
11661170
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
11671171

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@ class Strategy (
4343
val maxDepth: Int,
4444
val maxBins: Int = 100,
4545
val quantileCalculationStrategy: QuantileStrategy = Sort,
46-
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable
46+
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
47+
val maxMemory: Int = 128) extends Serializable

0 commit comments

Comments
 (0)