[SPARK-24333][ML][PYTHON] Add fit with validation set to spark.ml GBT: Python API #21465
Conversation
Test build #91317 has finished for PR 21465 at commit
python/pyspark/ml/classification.py
Outdated
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
                  maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
-                 featureSubsetStrategy="all"):
+                 featureSubsetStrategy="all", validationTol=0.01):
Shouldn't validationIndicatorCol be in init too? Set to None default?
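For concreteness, a sketch of the suggested signature (a sketch only; the defaults mirror the existing keyword arguments, with validationIndicatorCol=None as the proposed addition):

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
                 maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
                 featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None):
        ...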
@MLnick Yes, I should add it in init. Will change it now. Thanks a lot for your review!
Test build #91586 has finished for PR 21465 at commit
Test build #91798 has finished for PR 21465 at commit
Force-pushed from 4290b58 to 1169db8.
Test build #95893 has finished for PR 21465 at commit
Test build #98751 has finished for PR 21465 at commit
Test build #99013 has finished for PR 21465 at commit
Thanks @huaxingao for the PR! I think the new params should be added in GBTParams. While there, maybe you could add HasMaxIter and HasStepSize also, to match the Scala side.
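A rough sketch of that shared base class, assuming the usual pyspark.ml.param imports (Param, Params, TypeConverters) and pyspark.since; doc strings are abbreviated:

    class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
        """
        Private class to track supported GBT params, shared by GBTClassifier and GBTRegressor.
        """
        stepSize = Param(Params._dummy(), "stepSize",
                         "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking "
                         "the contribution of each estimator.",
                         typeConverter=TypeConverters.toFloat)

        validationTol = Param(Params._dummy(), "validationTol",
                              "Threshold for stopping early when fit with validation is used. "
                              "If the error rate on the validation input changes by less than "
                              "the validationTol, then learning will stop early (before maxIter).",
                              typeConverter=TypeConverters.toFloat)

        @since("3.0.0")
        def getValidationTol(self):
            """
            Gets the value of validationTol or its default value.
            """
            return self.getOrDefault(self.validationTol)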
python/pyspark/ml/classification.py
Outdated
-                 GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
-                 JavaMLReadable):
+                 GBTParams, HasCheckpointInterval, HasStepSize, HasSeed,
+                 HasValidationIndicatorCol, JavaMLWritable, JavaMLReadable):
I think this should be added to GBTParams, which is done on the Scala side too.
@BryanCutler Thank you very much for reviewing my PR. I moved HasValidationIndicatorCol, HasMaxIter and HasStepSize to GBTParams.
Test build #99136 has finished for PR 21465 at commit
Thanks @huaxingao, but let's also add GBTClassifierParams and GBTRegressorParams to handle lossType, as is done in Scala.
@@ -705,12 +705,38 @@ def getNumTrees(self):
         return self.getOrDefault(self.numTrees)


-class GBTParams(TreeEnsembleParams):
+class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
I like having a common GBTParams class; it was strange to have this defined in both estimators. But you should also define GBTClassifierParams and GBTRegressorParams, then put the supportedLossTypes in there so you don't need to override them later. You can also put the lossType param and getLossType() method there. This makes it clean and follows how it's done in Scala.
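A rough sketch of the classifier side (the regressor analog would carry supportedLossTypes = ["squared", "absolute"]); imports as above:

    class GBTClassifierParams(GBTParams):
        """
        Private class to track supported GBTClassifier params.
        """
        supportedLossTypes = ["logistic"]

        lossType = Param(Params._dummy(), "lossType",
                         "Loss function which GBT tries to minimize (case-insensitive). "
                         "Supported options: " + ", ".join(supportedLossTypes),
                         typeConverter=TypeConverters.toString)

        @since("1.4.0")
        def getLossType(self):
            """
            Gets the value of lossType or its default value.
            """
            return self.getOrDefault(self.lossType)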
python/pyspark/ml/regression.py
Outdated
-                 GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
-                 JavaMLReadable, TreeRegressorParams):
+class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, GBTParams,
+                   HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, JavaMLReadable,
I think you can remove HasStepSize since it is in GBTParams.
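With the shared params in place, the class line could slim down to something like this (a sketch, assuming the GBTRegressorParams class discussed above):

    class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                       GBTRegressorParams, HasCheckpointInterval, HasSeed,
                       JavaMLWritable, JavaMLReadable):
        ...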
Force-pushed from 88ff888 to c0fcbb3.
Test build #99408 has finished for PR 21465 at commit
Test build #99413 has finished for PR 21465 at commit
@huaxingao there are quite a lot of deviations from how these classes are in Scala; please follow how the class hierarchy is defined there and it should all fit together.
python/pyspark/ml/classification.py
Outdated
"Supported options: " + ", ".join(supportedLossTypes), | ||
typeConverter=TypeConverters.toString) | ||
|
||
@since("3.0.0") |
don't change the version, since we are just refactoring the base classes
please address the previous comment, to not change the since version since we are just refactoring the base class.
python/pyspark/ml/classification.py
Outdated
                 typeConverter=TypeConverters.toString)

    @since("3.0.0")
    def setLossType(self, value):
setLossType should be in the estimators, getLossType should be here.
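In other words, roughly:

    # In the shared params class (e.g. GBTClassifierParams): the getter only.
    @since("1.4.0")
    def getLossType(self):
        """
        Gets the value of lossType or its default value.
        """
        return self.getOrDefault(self.lossType)

    # In the estimator (GBTClassifier): the setter, since only the estimator mutates params.
    @since("1.4.0")
    def setLossType(self, value):
        """
        Sets the value of :py:attr:`lossType`.
        """
        return self._set(lossType=value)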
please address the above comment, this method should be in the estimator
@@ -1174,9 +1165,31 @@ def trees(self):
         return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]


+class GBTClassifierParams(GBTParams, HasVarianceImpurity):
this should extend TreeClassifierParams
@BryanCutler Thanks for your review. It seems #22986 recently added the trait HasVarianceImpurity and made:

    private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity
Ah, I see. Let me take another look.
Yeah, you're correct, this is fine
@@ -650,19 +650,20 @@ def getFeatureSubsetStrategy(self):
         return self.getOrDefault(self.featureSubsetStrategy)


-class TreeRegressorParams(Params):
+class HasVarianceImpurity(Params):
This shouldn't be changed; impurity is different for regression and classification, so the param needs to be defined in TreeRegressorParams and TreeClassifierParams, as it was already.

This is correct and matches Scala currently.
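For reference, the Python mirror of that trait would look roughly like this (a sketch; imports as in pyspark.ml.param):

    class HasVarianceImpurity(Params):
        """
        Private class to track supported impurity measures.
        """
        supportedImpurities = ["variance"]

        impurity = Param(Params._dummy(), "impurity",
                         "Criterion used for information gain calculation (case-insensitive). "
                         "Supported options: " + ", ".join(supportedImpurities),
                         typeConverter=TypeConverters.toString)

        @since("1.4.0")
        def getImpurity(self):
            """
            Gets the value of impurity or its default value.
            """
            return self.getOrDefault(self.impurity)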
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
-                 maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
-                 featureSubsetStrategy="all"):
+                 maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
this is not the correct default impurity

default value has been changed in Scala, this is correct
Please look at some of my previous comments and fix those, then I think it will be good to go, thanks!
python/pyspark/ml/regression.py
Outdated
                 typeConverter=TypeConverters.toFloat)

    @since("3.0.0")
    def setValidationTol(self, value):
It seems Scala does not have this API, right? If not, then let's remove it here for now.
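Without the setter, validationTol would still be settable through the constructor or setParams, for example:

    from pyspark.ml.regression import GBTRegressor

    # Hypothetical usage; "isVal" is an illustrative validation-indicator column name.
    gbt = GBTRegressor(validationIndicatorCol="isVal", validationTol=0.01)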
python/pyspark/ml/regression.py
Outdated
                 typeConverter=TypeConverters.toString)

    @since("1.4.0")
    def setLossType(self, value):
setLossType should be in the estimator and getLossType should be here.
@BryanCutler Thank you very much for your review! I will submit changes soon.
Test build #99744 has finished for PR 21465 at commit
python/pyspark/ml/param/shared.py
Outdated
@@ -814,3 +814,25 @@ def getDistanceMeasure(self):
         """
         return self.getOrDefault(self.distanceMeasure)


+class HasValidationIndicatorCol(Params):
Would you mind running the codegen again, for example:

    pushd python/pyspark/ml/param/ && python _shared_params_code_gen.py > shared.py && popd

and pushing the result if there is a diff? I think the DecisionTreeParams should be at the bottom of the file.
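For anyone following along: shared params are declared as (name, doc, default, converter) entries in _shared_params_code_gen.py, and the script regenerates shared.py wholesale, so hand edits there get overwritten. The regenerated mixin should come out roughly like:

    class HasValidationIndicatorCol(Params):
        """
        Mixin for param validationIndicatorCol.
        """
        validationIndicatorCol = Param(Params._dummy(), "validationIndicatorCol",
                                       "name of the column that indicates whether each row is for "
                                       "training or for validation. False indicates training; "
                                       "true indicates validation.",
                                       typeConverter=TypeConverters.toString)

        def __init__(self):
            super(HasValidationIndicatorCol, self).__init__()

        def getValidationIndicatorCol(self):
            """
            Gets the value of validationIndicatorCol or its default value.
            """
            return self.getOrDefault(self.validationIndicatorCol)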
You are right. DecisionTreeParams should be at the bottom.
Test build #99838 has finished for PR 21465 at commit
LGTM
Merged to master, thanks @huaxingao!
@BryanCutler Thank you very much for your help!
[SPARK-24333][ML][PYTHON] Add fit with validation set to spark.ml GBT: Python API

## What changes were proposed in this pull request?
Add validationIndicatorCol and validationTol to GBT Python.

## How was this patch tested?
Add test in doctest to test the new API.

Closes apache#21465 from huaxingao/spark-24333.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
What changes were proposed in this pull request?
Add validationIndicatorCol and validationTol to GBT Python.
How was this patch tested?
Add test in doctest to test the new API.
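As a quick sketch of the resulting API (column values and names here are illustrative, not the PR's actual doctest):

    from pyspark.ml.classification import GBTClassifier
    from pyspark.ml.linalg import Vectors
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # "isVal" marks validation rows: False = train on the row, True = use it for validation.
    df = spark.createDataFrame([
        (Vectors.dense(0.0), 0.0, False),
        (Vectors.dense(1.0), 1.0, False),
        (Vectors.dense(2.0), 1.0, True),
    ], ["features", "label", "isVal"])

    gbt = GBTClassifier(maxIter=20, validationIndicatorCol="isVal", validationTol=0.01)
    # Training stops early once the validation error changes by less than validationTol.
    model = gbt.fit(df)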