[SPARK-24333][ML][PYTHON] Add fit with validation set to spark.ml GBT: Python API #21465


Closed
huaxingao wants to merge 8 commits into apache:master from huaxingao:spark-24333

Conversation

huaxingao
Contributor

What changes were proposed in this pull request?

Add validationIndicatorCol and validationTol to GBT Python.

How was this patch tested?

Add test in doctest to test the new API.
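
For context, a minimal usage sketch of the new parameters (illustrative only: the SparkSession setup, data, and column names are made up; only validationIndicatorCol and validationTol come from this PR):

# Illustrative sketch: rows whose indicator column is True are held out as the
# validation set; training can stop early based on validationTol.
from pyspark.sql import SparkSession
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    (Vectors.dense(0.0), 0.0, False),
    (Vectors.dense(1.0), 1.0, False),
    (Vectors.dense(2.0), 1.0, True),
    (Vectors.dense(3.0), 0.0, True),
], ["features", "label", "isValidation"])

gbt = GBTClassifier(maxIter=20, validationIndicatorCol="isValidation", validationTol=0.01)
model = gbt.fit(df)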

@SparkQA

SparkQA commented May 31, 2018

Test build #91317 has finished for PR 21465 at commit 79fc83b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class HasValidationIndicatorCol(Params):

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
             maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
             maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
             maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
-            featureSubsetStrategy="all"):
+            featureSubsetStrategy="all", validationTol=0.01):
MLnick
Contributor


Shouldn't validationIndicatorCol be in init too? Set to None default?

huaxingao
Contributor Author


@MLnick Yes, I should add it in init. Will change it now. Thanks a lot for your review!
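
With that change, the keyword-only constructor would look roughly as follows (a sketch of the intended signature only, not the merged code; the surrounding class body is elided):

# Sketch of the adjusted signature discussed above: validationIndicatorCol
# joins the keyword-only args with a default of None.
from pyspark import keyword_only

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
             maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
             maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
             maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
             featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None):
    ...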

@SparkQA

SparkQA commented Jun 8, 2018

Test build #91586 has finished for PR 21465 at commit 4290b58.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jun 13, 2018

Test build #91798 has finished for PR 21465 at commit 4290b58.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Sep 10, 2018

Test build #95893 has finished for PR 21465 at commit 1169db8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 13, 2018

Test build #98751 has finished for PR 21465 at commit 1169db8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 19, 2018

Test build #99013 has finished for PR 21465 at commit 6e177a3.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Member

BryanCutler left a comment


Thanks @huaxingao for the PR! I think the new params should be added in GBTParams. While there, maybe you could add HasMaxIter and HasStepSize also, to match the Scala side.

-                    GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
-                    JavaMLReadable):
+                    GBTParams, HasCheckpointInterval, HasStepSize, HasSeed,
+                    HasValidationIndicatorCol, JavaMLWritable, JavaMLReadable):
BryanCutler
Member


I think this should be added to GBTParams, as is done on the Scala side too.

huaxingao
Contributor Author


@BryanCutler Thank you very much for reviewing my PR. I moved HasValidationIndicatorCol, HasMaxIter and HasStepSize to GBTParams.
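
Roughly, the shared base class then looks like this (a structural sketch, not the merged code verbatim; TreeEnsembleParams is assumed to be defined earlier in the same module, and the validationTol doc string is abridged):

from pyspark import since
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasStepSize, HasValidationIndicatorCol


class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
    """Params shared by the GBT classifier and regressor, mirroring Scala's GBTParams trait."""

    validationTol = Param(Params._dummy(), "validationTol",
                          "Threshold for stopping early when fit with validation is used.",
                          typeConverter=TypeConverters.toFloat)

    @since("3.0.0")
    def getValidationTol(self):
        """Gets the value of validationTol or its default value."""
        return self.getOrDefault(self.validationTol)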

@SparkQA

SparkQA commented Nov 21, 2018

Test build #99136 has finished for PR 21465 at commit 88ff888.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, GBTParams,
  • class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
  • class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, GBTParams,

Member

BryanCutler left a comment


Thanks @huaxingao, but let's also add GBTClassifierParams and GBTRegressorParams to handle lossType, as is done in Scala.

@@ -705,12 +705,38 @@ def getNumTrees(self):
        return self.getOrDefault(self.numTrees)


-class GBTParams(TreeEnsembleParams):
+class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
Member

BryanCutler commented Nov 27, 2018


I like having a common GBTParams class; it was strange to have this defined in both estimators. But you should also define GBTClassifierParams and GBTRegressorParams, then put the supportedLossTypes in there so you don't need to override them later. You can also put the lossType param and getLossType() method there. This makes it clean and follows how it's done in Scala.
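
The suggested split looks roughly like this (a sketch of the structure being discussed, with the base classes as they appear later in this review, not the merged code):

# Assumes the usual pyspark imports (Param, Params, TypeConverters, since) and the
# GBTParams / HasVarianceImpurity classes discussed elsewhere in this review.
class GBTClassifierParams(GBTParams, HasVarianceImpurity):
    """Loss-type handling lives on the params class, mirroring Scala's GBTClassifierParams."""

    supportedLossTypes = ["logistic"]

    lossType = Param(Params._dummy(), "lossType",
                     "Loss function which GBT tries to minimize (case-insensitive). " +
                     "Supported options: " + ", ".join(supportedLossTypes),
                     typeConverter=TypeConverters.toString)

    @since("1.4.0")
    def getLossType(self):
        """Gets the value of lossType or its default value."""
        return self.getOrDefault(self.lossType)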

-                   GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
-                   JavaMLReadable, TreeRegressorParams):
+class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, GBTParams,
+                   HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, JavaMLReadable,
Member


I think you can remove HasStepSize since it is in GBTParams

@SparkQA

SparkQA commented Nov 28, 2018

Test build #99408 has finished for PR 21465 at commit c0fcbb3.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class GBTClassifierParams(GBTParams, HasVarianceImpurity):
  • class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
  • class HasVarianceImpurity(Params):
  • class TreeRegressorParams(HasVarianceImpurity):
  • class GBTRegressorParams(GBTParams, TreeRegressorParams):
  • class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,

@SparkQA

SparkQA commented Nov 28, 2018

Test build #99413 has finished for PR 21465 at commit c0586bd.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Member

BryanCutler left a comment


@huaxingao there are quite a lot of deviations from how these classes are defined in Scala; please follow how the class hierarchy is laid out there and it should all fit together.

"Supported options: " + ", ".join(supportedLossTypes),
typeConverter=TypeConverters.toString)

@since("3.0.0")
Member


don't change the version, since we are just refactoring the base classes

Member

BryanCutler commented Dec 5, 2018


please address the previous comment, to not change the since version since we are just refactoring the base class.

typeConverter=TypeConverters.toString)

@since("3.0.0")
def setLossType(self, value):
Member


setLossType should be in the estimators, getLossType should be here

Member

BryanCutler commented Dec 5, 2018


please address the above comment, this method should be in the estimator
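
In that arrangement the setter sits on the estimator while the getter stays on the params class; a minimal sketch (base list approximate, other methods elided, usual pyspark imports assumed):

class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                    GBTClassifierParams, HasCheckpointInterval, HasSeed,
                    JavaMLWritable, JavaMLReadable):

    @since("1.4.0")
    def setLossType(self, value):
        """Sets the value of lossType; the matching getLossType stays on the params class."""
        return self._set(lossType=value)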

@@ -1174,9 +1165,31 @@ def trees(self):
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]


class GBTClassifierParams(GBTParams, HasVarianceImpurity):
Member


this should extend TreeClassifierParams

huaxingao
Contributor Author


@BryanCutler Thanks for your review.
It seems #22986 recently added the trait HasVarianceImpurity and made
private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity

Member


ah, I see. let me take another look..

Member


Yeah, you're correct, this is fine

@@ -650,19 +650,20 @@ def getFeatureSubsetStrategy(self):
        return self.getOrDefault(self.featureSubsetStrategy)


-class TreeRegressorParams(Params):
+class HasVarianceImpurity(Params):
Member

BryanCutler commented Dec 4, 2018


This shouldn't be changed; impurity is different for regression and classification, so the param needs to be defined in TreeRegressorParams and TreeClassifierParams, as it was already. (Update: this is correct and matches Scala currently.)
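
For reference, the factored-out mixin looks roughly like this (a sketch based on the class names in the build output above; the param doc string is abridged and the usual pyspark imports are assumed):

class HasVarianceImpurity(Params):
    """Variance impurity factored into its own mixin so both GBT params classes can share it."""

    supportedImpurities = ["variance"]

    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation (case-insensitive). " +
                     "Supported options: " + ", ".join(supportedImpurities),
                     typeConverter=TypeConverters.toString)

    @since("1.4.0")
    def getImpurity(self):
        """Gets the value of impurity or its default value."""
        return self.getOrDefault(self.impurity)


class TreeRegressorParams(HasVarianceImpurity):
    """Kept as a thin subclass of HasVarianceImpurity; body elided in this sketch."""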

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
             maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
             maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
-            maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
-            featureSubsetStrategy="all"):
+            maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
Member

BryanCutler commented Dec 4, 2018


This is not the correct default impurity. (Update: the default impurity value has been changed in Scala, so this is correct.)

Member

BryanCutler left a comment


Please look at some of my previous comments and fix those, then I think it will be good to go, thanks!

typeConverter=TypeConverters.toFloat)

@since("3.0.0")
def setValidationTol(self, value):
Member


It seems Scala does not have this API, right? If not, then let's remove it here for now.


typeConverter=TypeConverters.toString)

@since("1.4.0")
def setLossType(self, value):
Member


setLossType should be in the estimator and getLossType should be here

@huaxingao
Contributor Author

@BryanCutler Thank you very much for your review! I will submit changes soon.

@SparkQA

SparkQA commented Dec 5, 2018

Test build #99744 has finished for PR 21465 at commit 30a743d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -814,3 +814,25 @@ def getDistanceMeasure(self):
"""
return self.getOrDefault(self.distanceMeasure)


class HasValidationIndicatorCol(Params):
Member


Would you mind running the codegen again, for example with
pushd python/pyspark/ml/param/ && python _shared_params_code_gen.py > shared.py && popd
and pushing the result if there is a diff? I think the DecisionTreeParams should be at the bottom of the file.

huaxingao
Contributor Author


You are right. DecisionTreeParams should be at the bottom.
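
After re-running the codegen, the generated mixin in shared.py looks roughly like this (a sketch; the param doc string is paraphrased and the generated __init__ boilerplate is trimmed):

from pyspark.ml.param import Param, Params, TypeConverters


class HasValidationIndicatorCol(Params):
    """
    Mixin for param validationIndicatorCol: name of the column that indicates
    whether each row is for training or for validation.
    """

    validationIndicatorCol = Param(Params._dummy(), "validationIndicatorCol",
                                   "name of the column that indicates whether each row is for "
                                   "training or for validation. False indicates training; true "
                                   "indicates validation.",
                                   typeConverter=TypeConverters.toString)

    def getValidationIndicatorCol(self):
        """Gets the value of validationIndicatorCol or its default value."""
        return self.getOrDefault(self.validationIndicatorCol)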

@SparkQA

SparkQA commented Dec 7, 2018

Test build #99838 has finished for PR 21465 at commit 6fc95a7.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class HasDistanceMeasure(Params):
  • class HasValidationIndicatorCol(Params):

Member

BryanCutler left a comment


LGTM

@BryanCutler
Member

Merged to master, thanks @huaxingao!

asfgit closed this in 20278e7 on Dec 7, 2018
@huaxingao
Contributor Author

@BryanCutler Thank you very much for your help!

jackylee-ch pushed a commit to jackylee-ch/spark that referenced this pull request Feb 18, 2019
…: Python API

## What changes were proposed in this pull request?

Add validationIndicatorCol and validationTol to GBT Python.

## How was this patch tested?

Add test in doctest to test the new API.

Closes apache#21465 from huaxingao/spark-24333.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>