
Commit 20278e7

huaxingao authored and BryanCutler committed
[SPARK-24333][ML][PYTHON] Add fit with validation set to spark.ml GBT: Python API
## What changes were proposed in this pull request?

Add validationIndicatorCol and validationTol to GBT Python.

## How was this patch tested?

Add test in doctest to test the new API.

Closes #21465 from huaxingao/spark-24333.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
1 parent 3b8ae23 commit 20278e7
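
For context, a minimal, hypothetical sketch of the API this commit exposes in PySpark: fitting a GBTClassifier with an early-stopping validation set selected by a boolean indicator column. The toy data, the column name "validationIndicator", and the parameter values are illustrative, not taken from the patch.

# Hypothetical usage sketch (not part of the patch): rows with
# validationIndicator == False are used for training, True for validation.
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    (Vectors.dense(0.0), 0.0, False),
    (Vectors.dense(1.0), 1.0, False),
    (Vectors.dense(2.0), 0.0, False),
    (Vectors.dense(3.0), 1.0, True),
    (Vectors.dense(4.0), 0.0, True),
], ["features", "label", "validationIndicator"])

gbt = GBTClassifier(maxIter=20, maxDepth=2,
                    validationIndicatorCol="validationIndicator",
                    validationTol=0.01)
model = gbt.fit(df)
# Training may stop before maxIter trees if the validation error stops improving.
print(len(model.trees))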

File tree

4 files changed: +169, -88 lines


python/pyspark/ml/classification.py

Lines changed: 46 additions & 35 deletions
@@ -23,7 +23,7 @@
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \
-    RandomForestParams, TreeEnsembleModel, TreeEnsembleParams
+    GBTParams, HasVarianceImpurity, RandomForestParams, TreeEnsembleModel, TreeEnsembleParams
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.wrapper import JavaWrapper
@@ -895,15 +895,6 @@ def getImpurity(self):
         return self.getOrDefault(self.impurity)


-class GBTParams(TreeEnsembleParams):
-    """
-    Private class to track supported GBT params.
-
-    .. versionadded:: 1.4.0
-    """
-    supportedLossTypes = ["logistic"]
-
-
 @inherit_doc
 class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                              HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
@@ -1174,9 +1165,31 @@ def trees(self):
         return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]


+class GBTClassifierParams(GBTParams, HasVarianceImpurity):
+    """
+    Private class to track supported GBTClassifier params.
+
+    .. versionadded:: 3.0.0
+    """
+
+    supportedLossTypes = ["logistic"]
+
+    lossType = Param(Params._dummy(), "lossType",
+                     "Loss function which GBT tries to minimize (case-insensitive). " +
+                     "Supported options: " + ", ".join(supportedLossTypes),
+                     typeConverter=TypeConverters.toString)
+
+    @since("1.4.0")
+    def getLossType(self):
+        """
+        Gets the value of lossType or its default value.
+        """
+        return self.getOrDefault(self.lossType)
+
+
 @inherit_doc
-class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
-                    GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
+class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
+                    GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
                     JavaMLReadable):
     """
     `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
@@ -1242,40 +1255,36 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
     >>> model.numClasses
     2
+    >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
+    >>> gbt.getValidationIndicatorCol()
+    'validationIndicator'
+    >>> gbt.getValidationTol()
+    0.01

     .. versionadded:: 1.4.0
     """

-    lossType = Param(Params._dummy(), "lossType",
-                     "Loss function which GBT tries to minimize (case-insensitive). " +
-                     "Supported options: " + ", ".join(GBTParams.supportedLossTypes),
-                     typeConverter=TypeConverters.toString)
-
-    stepSize = Param(Params._dummy(), "stepSize",
-                     "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
-                     "the contribution of each estimator.",
-                     typeConverter=TypeConverters.toFloat)
-
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
-                 maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
-                 featureSubsetStrategy="all"):
+                 maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
+                 featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
-                 featureSubsetStrategy="all")
+                 impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
+                 validationIndicatorCol=None)
         """
         super(GBTClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.GBTClassifier", self.uid)
         self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                          maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                          lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
-                         featureSubsetStrategy="all")
+                         impurity="variance", featureSubsetStrategy="all", validationTol=0.01)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)

@@ -1285,13 +1294,15 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                   lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
-                  featureSubsetStrategy="all"):
+                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
+                  validationIndicatorCol=None):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                   lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
-                  featureSubsetStrategy="all")
+                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
+                  validationIndicatorCol=None)
         Sets params for Gradient Boosted Tree Classification.
         """
         kwargs = self._input_kwargs
@@ -1307,20 +1318,20 @@ def setLossType(self, value):
         """
         return self._set(lossType=value)

-    @since("1.4.0")
-    def getLossType(self):
-        """
-        Gets the value of lossType or its default value.
-        """
-        return self.getOrDefault(self.lossType)
-
     @since("2.4.0")
     def setFeatureSubsetStrategy(self, value):
         """
         Sets the value of :py:attr:`featureSubsetStrategy`.
         """
         return self._set(featureSubsetStrategy=value)

+    @since("3.0.0")
+    def setValidationIndicatorCol(self, value):
+        """
+        Sets the value of :py:attr:`validationIndicatorCol`.
+        """
+        return self._set(validationIndicatorCol=value)
+

 class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
                              JavaMLReadable):
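
As a supplement to the doctest above, the updated signatures allow the new parameters to be supplied through the constructor or setParams as well. A hedged sketch (requires an active SparkSession; the column name and tolerance value are illustrative):

# Hypothetical sketch: passing the new params as keyword arguments.
from pyspark.sql import SparkSession
from pyspark.ml.classification import GBTClassifier

spark = SparkSession.builder.getOrCreate()

gbt = GBTClassifier(maxIter=10)
gbt.setParams(validationIndicatorCol="validationIndicator", validationTol=0.005)
print(gbt.getValidationIndicatorCol())  # 'validationIndicator'
print(gbt.getValidationTol())           # 0.005 (default is 0.01)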

python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 4 additions & 1 deletion
@@ -164,7 +164,10 @@ def get$Name(self):
          "False", "TypeConverters.toBoolean"),
         ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"),
         ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.",
-         "'euclidean'", "TypeConverters.toString")]
+         "'euclidean'", "TypeConverters.toString"),
+        ("validationIndicatorCol", "name of the column that indicates whether each row is for " +
+         "training or for validation. False indicates training; true indicates validation.",
+         None, "TypeConverters.toString")]

     code = []
     for name, doc, defaultValueStr, typeConverter in shared:
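
For readers unfamiliar with this file: each ("name", doc, default, converter) tuple above is expanded by the code generator into a HasX mixin in shared.py. The snippet below is a simplified, hypothetical illustration of that expansion; the real template in _shared_params_code_gen.py differs in detail.

# Simplified, hypothetical illustration of the codegen expansion (pure Python, runnable).
entry = ("validationIndicatorCol",
         "name of the column that indicates whether each row is for training or for "
         "validation. False indicates training; true indicates validation.",
         None, "TypeConverters.toString")

template = '''class Has{Name}(Params):
    """
    Mixin for param {name}: {doc}
    """

    {name} = Param(Params._dummy(), "{name}", "{doc}", typeConverter={converter})

    def get{Name}(self):
        """
        Gets the value of {name} or its default value.
        """
        return self.getOrDefault(self.{name})
'''

name, doc, default, converter = entry
print(template.format(name=name, Name=name[0].upper() + name[1:], doc=doc, converter=converter))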

python/pyspark/ml/param/shared.py

Lines changed: 47 additions & 24 deletions
@@ -702,6 +702,53 @@ def getLoss(self):
         return self.getOrDefault(self.loss)


+class HasDistanceMeasure(Params):
+    """
+    Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'.
+    """
+
+    distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString)
+
+    def __init__(self):
+        super(HasDistanceMeasure, self).__init__()
+        self._setDefault(distanceMeasure='euclidean')
+
+    def setDistanceMeasure(self, value):
+        """
+        Sets the value of :py:attr:`distanceMeasure`.
+        """
+        return self._set(distanceMeasure=value)
+
+    def getDistanceMeasure(self):
+        """
+        Gets the value of distanceMeasure or its default value.
+        """
+        return self.getOrDefault(self.distanceMeasure)
+
+
+class HasValidationIndicatorCol(Params):
+    """
+    Mixin for param validationIndicatorCol: name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
+    """
+
+    validationIndicatorCol = Param(Params._dummy(), "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.", typeConverter=TypeConverters.toString)
+
+    def __init__(self):
+        super(HasValidationIndicatorCol, self).__init__()
+
+    def setValidationIndicatorCol(self, value):
+        """
+        Sets the value of :py:attr:`validationIndicatorCol`.
+        """
+        return self._set(validationIndicatorCol=value)
+
+    def getValidationIndicatorCol(self):
+        """
+        Gets the value of validationIndicatorCol or its default value.
+        """
+        return self.getOrDefault(self.validationIndicatorCol)
+
+
 class DecisionTreeParams(Params):
     """
     Mixin for Decision Tree parameters.
@@ -790,27 +837,3 @@ def getCacheNodeIds(self):
         """
         return self.getOrDefault(self.cacheNodeIds)

-
-class HasDistanceMeasure(Params):
-    """
-    Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'.
-    """
-
-    distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString)
-
-    def __init__(self):
-        super(HasDistanceMeasure, self).__init__()
-        self._setDefault(distanceMeasure='euclidean')
-
-    def setDistanceMeasure(self, value):
-        """
-        Sets the value of :py:attr:`distanceMeasure`.
-        """
-        return self._set(distanceMeasure=value)
-
-    def getDistanceMeasure(self):
-        """
-        Gets the value of distanceMeasure or its default value.
-        """
-        return self.getOrDefault(self.distanceMeasure)
-
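
The new HasValidationIndicatorCol mixin is a shared param intended for reuse by estimators that support validation-based early stopping. A minimal, hypothetical sketch of reusing it in a custom Params subclass (the class name and column name are made up):

# Hypothetical reuse sketch: a Params subclass inheriting the mixin gets the param
# plus its getter and setter for free. Needs pyspark importable, but no SparkSession.
from pyspark.ml.param.shared import HasValidationIndicatorCol

class MyValidatedParams(HasValidationIndicatorCol):
    """Toy Params holder used only to exercise the new mixin."""
    pass

params = MyValidatedParams()
params.setValidationIndicatorCol("isValidation")
print(params.getValidationIndicatorCol())  # 'isValidation'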