Skip to content

Commit

Permalink
[SPARK-15092][SPARK-15139][PYSPARK][ML] Pyspark TreeEnsemble missing …
Browse files Browse the repository at this point in the history
…methods

## What changes were proposed in this pull request?

Add `toDebugString` and `totalNumNodes` to `TreeEnsembleModels` and add `toDebugString` to `DecisionTreeModel`

## How was this patch tested?

Extended doc tests.

Author: Holden Karau <holden@us.ibm.com>

Closes apache#12919 from holdenk/SPARK-15139-pyspark-treeEnsemble-missing-methods.
  • Loading branch information
holdenk authored and Nick Pentreath committed Jun 2, 2016
1 parent d109a1b commit 7235331
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
20 changes: 20 additions & 0 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> print(model.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
Expand Down Expand Up @@ -650,6 +652,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.trees
[DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
>>> rfc_path = temp_path + "/rfc"
>>> rf.save(rfc_path)
>>> rf2 = RandomForestClassifier.load(rfc_path)
Expand Down Expand Up @@ -730,6 +734,12 @@ def featureImportances(self):
"""
return self._call_java("featureImportances")

@property
@since("2.0.0")
def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]


@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
Expand Down Expand Up @@ -772,6 +782,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.totalNumNodes
15
>>> print(model.toDebugString)
GBTClassificationModel (uid=...)...with 5 trees...
>>> gbtc_path = temp_path + "gbtc"
>>> gbt.save(gbtc_path)
>>> gbt2 = GBTClassifier.load(gbtc_path)
Expand Down Expand Up @@ -869,6 +883,12 @@ def featureImportances(self):
"""
return self._call_java("featureImportances")

@property
@since("2.0.0")
def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]


@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
Expand Down
48 changes: 47 additions & 1 deletion python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ class RandomForestParams(TreeEnsembleParams):
featureSubsetStrategy = \
Param(Params._dummy(), "featureSubsetStrategy",
"The number of features to consider for splits at each tree node. Supported " +
"options: " + ", ".join(supportedFeatureSubsetStrategies),
"options: " + ", ".join(supportedFeatureSubsetStrategies) + " (0.0-1.0], [1-n].",
typeConverter=TypeConverters.toString)

def __init__(self):
Expand Down Expand Up @@ -744,6 +744,12 @@ def depth(self):
"""Return depth of the decision tree."""
return self._call_java("depth")

@property
@since("2.0.0")
def toDebugString(self):
"""Full description of model."""
return self._call_java("toDebugString")

def __repr__(self):
return self._call_java("toString")

Expand All @@ -758,12 +764,36 @@ class TreeEnsembleModels(JavaModel):
.. versionadded:: 1.5.0
"""

@property
@since("2.0.0")
def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeModel(m) for m in list(self._call_java("trees"))]

@property
@since("2.0.0")
def getNumTrees(self):
"""Number of trees in ensemble."""
return self._call_java("getNumTrees")

@property
@since("1.5.0")
def treeWeights(self):
"""Return the weights for each tree"""
return list(self._call_java("javaTreeWeights"))

@property
@since("2.0.0")
def totalNumNodes(self):
"""Total number of nodes, summed over all trees in the ensemble."""
return self._call_java("totalNumNodes")

@property
@since("2.0.0")
def toDebugString(self):
"""Full description of model."""
return self._call_java("toDebugString")

def __repr__(self):
return self._call_java("toString")

Expand Down Expand Up @@ -825,6 +855,10 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
>>> model.getNumTrees
2
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
0.5
Expand Down Expand Up @@ -896,6 +930,12 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead
.. versionadded:: 1.4.0
"""

@property
@since("2.0.0")
def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

@property
@since("2.0.0")
def featureImportances(self):
Expand Down Expand Up @@ -1045,6 +1085,12 @@ def featureImportances(self):
"""
return self._call_java("featureImportances")

@property
@since("2.0.0")
def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]


@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
Expand Down

0 comments on commit 7235331

Please sign in to comment.