You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[SPARK-24333][ML][PYTHON] Add fit with validation set to spark.ml GBT: Python API
## What changes were proposed in this pull request?
Add validationIndicatorCol and validationTol to GBT Python.
## How was this patch tested?
Add test in doctest to test the new API.
Closes#21465 from huaxingao/spark-24333.
Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
Copy file name to clipboardExpand all lines: python/pyspark/ml/param/shared.py
+47-24Lines changed: 47 additions & 24 deletions
Original file line number
Diff line number
Diff line change
@@ -702,6 +702,53 @@ def getLoss(self):
702
702
returnself.getOrDefault(self.loss)
703
703
704
704
705
+
classHasDistanceMeasure(Params):
706
+
"""
707
+
Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'.
708
+
"""
709
+
710
+
distanceMeasure=Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString)
711
+
712
+
def__init__(self):
713
+
super(HasDistanceMeasure, self).__init__()
714
+
self._setDefault(distanceMeasure='euclidean')
715
+
716
+
defsetDistanceMeasure(self, value):
717
+
"""
718
+
Sets the value of :py:attr:`distanceMeasure`.
719
+
"""
720
+
returnself._set(distanceMeasure=value)
721
+
722
+
defgetDistanceMeasure(self):
723
+
"""
724
+
Gets the value of distanceMeasure or its default value.
725
+
"""
726
+
returnself.getOrDefault(self.distanceMeasure)
727
+
728
+
729
+
classHasValidationIndicatorCol(Params):
730
+
"""
731
+
Mixin for param validationIndicatorCol: name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
732
+
"""
733
+
734
+
validationIndicatorCol=Param(Params._dummy(), "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.", typeConverter=TypeConverters.toString)
735
+
736
+
def__init__(self):
737
+
super(HasValidationIndicatorCol, self).__init__()
738
+
739
+
defsetValidationIndicatorCol(self, value):
740
+
"""
741
+
Sets the value of :py:attr:`validationIndicatorCol`.
742
+
"""
743
+
returnself._set(validationIndicatorCol=value)
744
+
745
+
defgetValidationIndicatorCol(self):
746
+
"""
747
+
Gets the value of validationIndicatorCol or its default value.
0 commit comments