
Commit bb9f610

Small cleanups after original tree API PR

1 parent a83571a commit bb9f610

File tree

3 files changed: +17 -10 lines changed

examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala

+12 -6

@@ -70,7 +70,7 @@ object DecisionTreeExample {
     val parser = new OptionParser[Params]("DecisionTreeExample") {
       head("DecisionTreeExample: an example decision tree app.")
       opt[String]("algo")
-        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
         .action((x, c) => c.copy(algo = x))
       opt[Int]("maxDepth")
         .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
@@ -221,18 +221,23 @@ object DecisionTreeExample {
     // (1) For classification, re-index classes.
     val labelColName = if (algo == "classification") "indexedLabel" else "label"
     if (algo == "classification") {
-      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+      val labelIndexer = new StringIndexer()
+        .setInputCol("labelString")
+        .setOutputCol(labelColName)
       stages += labelIndexer
     }
     // (2) Identify categorical features using VectorIndexer.
     //     Features with more than maxCategories values will be treated as continuous.
-    val featuresIndexer = new VectorIndexer().setInputCol("features")
-      .setOutputCol("indexedFeatures").setMaxCategories(10)
+    val featuresIndexer = new VectorIndexer()
+      .setInputCol("features")
+      .setOutputCol("indexedFeatures")
+      .setMaxCategories(10)
     stages += featuresIndexer
     // (3) Learn DecisionTree
     val dt = algo match {
       case "classification" =>
-        new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+        new DecisionTreeClassifier()
+          .setFeaturesCol("indexedFeatures")
           .setLabelCol(labelColName)
           .setMaxDepth(params.maxDepth)
           .setMaxBins(params.maxBins)
@@ -241,7 +246,8 @@ object DecisionTreeExample {
           .setCacheNodeIds(params.cacheNodeIds)
           .setCheckpointInterval(params.checkpointInterval)
       case "regression" =>
-        new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+        new DecisionTreeRegressor()
+          .setFeaturesCol("indexedFeatures")
          .setLabelCol(labelColName)
          .setMaxDepth(params.maxDepth)
          .setMaxBins(params.maxBins)
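The hunks above only restyle the fluent setter calls; the example's overall flow, indexing the label and features and then fitting a decision tree, is unchanged. As a rough sketch of that pattern (not copied from the example; column names, parameter values, and the trainingData variable below are placeholders), a classification variant could be wired into a Pipeline like this:

// Sketch of the Pipeline wiring used by the example; values are placeholders.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}

val labelIndexer = new StringIndexer()
  .setInputCol("labelString")
  .setOutputCol("indexedLabel")

val featuresIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(10)

val dt = new DecisionTreeClassifier()
  .setFeaturesCol("indexedFeatures")
  .setLabelCol("indexedLabel")
  .setMaxDepth(5)
  .setMaxBins(32)

val pipeline = new Pipeline().setStages(Array(labelIndexer, featuresIndexer, dt))
// val model = pipeline.fit(trainingData)  // trainingData: DataFrame with "labelString" and "features"

Putting one setter per line, as the diff does, keeps these builder blocks readable as more parameters are added.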

mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala

+1 -1

@@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
   def getImpurity: String = getOrDefault(impurity)

   /** Convert new impurity to old impurity. */
-  protected def getOldImpurity: OldImpurity = {
+  private[ml] def getOldImpurity: OldImpurity = {
     getImpurity match {
       case "variance" => OldVariance
       case _ =>
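For context, the body of getOldImpurity (truncated in the hunk) maps the string impurity param onto the old spark.mllib impurity objects. A standalone sketch of that mapping for the regressor case, assuming the old org.apache.spark.mllib.tree.impurity API and with guessed error handling, might be:

// Illustrative mapping from the string "impurity" setting to the old mllib impurity object.
// This mirrors what getOldImpurity does conceptually; the error message is a guess.
import org.apache.spark.mllib.tree.impurity.{Impurity => OldImpurity, Variance => OldVariance}

def toOldRegressorImpurity(impurity: String): OldImpurity = impurity match {
  case "variance" => OldVariance
  case other =>
    throw new RuntimeException(s"Unrecognized regressor impurity: $other")
}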

mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala

+4 -3

@@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
   private[tree] def toOld: OldSplit
 }

-private[ml] object Split {
+private[tree] object Split {

   def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
     oldSplit.featureType match {
@@ -58,7 +58,7 @@ private[ml] object Split {
 *        left. Otherwise, it goes right.
 * @param numCategories Number of categories for this feature.
 */
-final class CategoricalSplit(
+final class CategoricalSplit private[ml] (
    override val featureIndex: Int,
    leftCategories: Array[Double],
    private val numCategories: Int)
@@ -130,7 +130,8 @@ final class CategoricalSplit(
 * @param threshold If the feature value is <= this threshold, then the split goes left.
 *                  Otherwise, it goes right.
 */
-final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
+  extends Split {

   override private[ml] def shouldGoLeft(features: Vector): Boolean = {
     features(featureIndex) <= threshold
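These hunks only tighten visibility: construction of the split classes moves behind private[ml], and the Split companion object becomes private[tree]. The split semantics themselves are untouched. As a simplified, self-contained illustration of those semantics, using plain stand-in classes rather than Spark's own types, the left/right rule looks like:

// Simplified stand-ins for the split types, to illustrate only the left/right decision rule.
sealed trait SimpleSplit {
  def featureIndex: Int
  def shouldGoLeft(features: Array[Double]): Boolean
}

// Continuous split: go left iff the feature value is <= the threshold.
final case class SimpleContinuousSplit(featureIndex: Int, threshold: Double) extends SimpleSplit {
  def shouldGoLeft(features: Array[Double]): Boolean = features(featureIndex) <= threshold
}

// Categorical split: go left iff the feature's category is in the "left" category set.
final case class SimpleCategoricalSplit(featureIndex: Int, leftCategories: Set[Double])
  extends SimpleSplit {
  def shouldGoLeft(features: Array[Double]): Boolean = leftCategories.contains(features(featureIndex))
}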
