Skip to content

[SPARK-11938][ML] Expose numFeatures in all ML PredictionModel for Py… #9936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/pyspark/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,17 @@ class Model(Transformer):
"""

__metaclass__ = ABCMeta


class HasNumFeaturesModel:
"""
Provides a getter for the number of features, intended for model classes.
It should be mixed in with JavaModel.
"""
@property
@since("1.7.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove since tag. If this is added to another class in a later version, the since tag will be incorrect.

def numFeatures(self):
"""
The number of features used to train the model.
"""
return self._call_java("numFeatures")
25 changes: 19 additions & 6 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import warnings

from pyspark import since
from pyspark.ml.base import *
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import only the needed class (HasNumFeaturesModel) instead of using a wildcard import.

from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param import TypeConverters
Expand Down Expand Up @@ -200,7 +201,7 @@ def _checkThresholdConsistency(self):
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))


class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
class LogisticRegressionModel(HasNumFeaturesModel, JavaModel, JavaMLWritable, JavaMLReadable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put HasNumFeaturesModel at end of list (here and elsewhere)

"""
Model fitted by LogisticRegression.

Expand Down Expand Up @@ -324,6 +325,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> model2 = DecisionTreeClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.numFeatures
1

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -373,7 +376,8 @@ def _create_model(self, java_model):


@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
class DecisionTreeClassificationModel(HasNumFeaturesModel, DecisionTreeModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.

Expand Down Expand Up @@ -439,6 +443,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.numFeatures
1

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -487,7 +493,7 @@ def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)


class RandomForestClassificationModel(TreeEnsembleModels):
class RandomForestClassificationModel(HasNumFeaturesModel, TreeEnsembleModels):
"""
Model fitted by RandomForestClassifier.

Expand Down Expand Up @@ -540,6 +546,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.numFeatures
1

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -604,7 +612,7 @@ def getLossType(self):
return self.getOrDefault(self.lossType)


class GBTClassificationModel(TreeEnsembleModels):
class GBTClassificationModel(HasNumFeaturesModel, TreeEnsembleModels):
"""
Model fitted by GBTClassifier.

Expand Down Expand Up @@ -675,6 +683,8 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
True
>>> model.theta == model2.theta
True
>>> model.numFeatures
2

.. versionadded:: 1.5.0
"""
Expand Down Expand Up @@ -749,7 +759,7 @@ def getModelType(self):
return self.getOrDefault(self.modelType)


class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
class NaiveBayesModel(HasNumFeaturesModel, JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.

Expand Down Expand Up @@ -817,6 +827,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
True
>>> model.weights == model2.weights
True
>>> model.numFeatures
2

.. versionadded:: 1.6.0
"""
Expand Down Expand Up @@ -894,7 +906,8 @@ def getBlockSize(self):
return self.getOrDefault(self.blockSize)


class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
class MultilayerPerceptronClassificationModel(HasNumFeaturesModel, JavaModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by MultilayerPerceptronClassifier.

Expand Down
18 changes: 14 additions & 4 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import warnings

from pyspark import since
from pyspark.ml.base import HasNumFeaturesModel
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
Expand Down Expand Up @@ -80,6 +81,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
True
>>> model.intercept == model2.intercept
True
>>> model.numFeatures
1

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -118,7 +121,7 @@ def _create_model(self, java_model):
return LinearRegressionModel(java_model)


class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
class LinearRegressionModel(HasNumFeaturesModel, JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LinearRegression.

Expand Down Expand Up @@ -425,6 +428,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
True
>>> model.depth == model2.depth
True
>>> model.numFeatures
1

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -510,7 +515,8 @@ def __repr__(self):


@inherit_doc
class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
class DecisionTreeRegressionModel(HasNumFeaturesModel, DecisionTreeModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by DecisionTreeRegressor.

Expand Down Expand Up @@ -564,6 +570,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
0.5
>>> model.numFeatures
1

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -613,7 +621,7 @@ def _create_model(self, java_model):
return RandomForestRegressionModel(java_model)


class RandomForestRegressionModel(TreeEnsembleModels):
class RandomForestRegressionModel(HasNumFeaturesModel, TreeEnsembleModels):
"""
Model fitted by RandomForestRegressor.

Expand Down Expand Up @@ -661,6 +669,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.numFeatures
1

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -725,7 +735,7 @@ def getLossType(self):
return self.getOrDefault(self.lossType)


class GBTRegressionModel(TreeEnsembleModels):
class GBTRegressionModel(HasNumFeaturesModel, TreeEnsembleModels):
"""
Model fitted by GBTRegressor.

Expand Down