Skip to content

Commit be5d95a

Browse files
huaxingao authored and srowen committed
[SPARK-27007][PYTHON] add rawPrediction to OneVsRest in PySpark
## What changes were proposed in this pull request?

Add RawPrediction to OneVsRest in PySpark to make it consistent with the Scala implementation.

## How was this patch tested?

Add doctest.

Closes #23910 from huaxingao/spark-27007.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
1 parent a97a19d commit be5d95a

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

python/pyspark/ml/classification.py

Lines changed: 35 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
2929
from pyspark.ml.wrapper import JavaWrapper
3030
from pyspark.ml.common import inherit_doc, _java2py, _py2java
31+
from pyspark.ml.linalg import Vectors
3132
from pyspark.sql import DataFrame
3233
from pyspark.sql.functions import udf, when
3334
from pyspark.sql.types import ArrayType, DoubleType
@@ -1717,7 +1718,8 @@ def weights(self):
17171718
return self._call_java("weights")
17181719

17191720

1720-
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
1721+
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol,
1722+
HasRawPredictionCol):
17211723
"""
17221724
Parameters for OneVsRest and OneVsRestModel.
17231725
"""
@@ -1758,6 +1760,8 @@ class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, Java
17581760
>>> df = spark.read.format("libsvm").load(data_path)
17591761
>>> lr = LogisticRegression(regParam=0.01)
17601762
>>> ovr = OneVsRest(classifier=lr)
1763+
>>> ovr.getRawPredictionCol()
1764+
'rawPrediction'
17611765
>>> model = ovr.fit(df)
17621766
>>> model.models[0].coefficients
17631767
DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
@@ -1781,16 +1785,18 @@ class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, Java
17811785
>>> model2 = OneVsRestModel.load(model_path)
17821786
>>> model2.transform(test0).head().prediction
17831787
0.0
1788+
>>> model.transform(test2).columns
1789+
['features', 'rawPrediction', 'prediction']
17841790
17851791
.. versionadded:: 2.0.0
17861792
"""
17871793

17881794
@keyword_only
17891795
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1790-
classifier=None, weightCol=None, parallelism=1):
1796+
rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
17911797
"""
17921798
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1793-
classifier=None, weightCol=None, parallelism=1):
1799+
rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
17941800
"""
17951801
super(OneVsRest, self).__init__()
17961802
self._setDefault(parallelism=1)
@@ -1800,10 +1806,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
18001806
@keyword_only
18011807
@since("2.0.0")
18021808
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
1803-
classifier=None, weightCol=None, parallelism=1):
1809+
rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
18041810
"""
18051811
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
1806-
classifier=None, weightCol=None, parallelism=1):
1812+
rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
18071813
Sets params for OneVsRest.
18081814
"""
18091815
kwargs = self._input_kwargs
@@ -1814,8 +1820,6 @@ def _fit(self, dataset):
18141820
featuresCol = self.getFeaturesCol()
18151821
predictionCol = self.getPredictionCol()
18161822
classifier = self.getClassifier()
1817-
assert isinstance(classifier, HasRawPredictionCol),\
1818-
"Classifier %s doesn't extend from HasRawPredictionCol." % type(classifier)
18191823

18201824
numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
18211825

@@ -1884,10 +1888,12 @@ def _from_java(cls, java_stage):
18841888
featuresCol = java_stage.getFeaturesCol()
18851889
labelCol = java_stage.getLabelCol()
18861890
predictionCol = java_stage.getPredictionCol()
1891+
rawPredictionCol = java_stage.getRawPredictionCol()
18871892
classifier = JavaParams._from_java(java_stage.getClassifier())
18881893
parallelism = java_stage.getParallelism()
18891894
py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
1890-
classifier=classifier, parallelism=parallelism)
1895+
rawPredictionCol=rawPredictionCol, classifier=classifier,
1896+
parallelism=parallelism)
18911897
py_stage._resetUid(java_stage.uid())
18921898
return py_stage
18931899

@@ -1904,6 +1910,7 @@ def _to_java(self):
19041910
_java_obj.setFeaturesCol(self.getFeaturesCol())
19051911
_java_obj.setLabelCol(self.getLabelCol())
19061912
_java_obj.setPredictionCol(self.getPredictionCol())
1913+
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
19071914
return _java_obj
19081915

19091916
def _make_java_param_pair(self, param, value):
@@ -1994,7 +2001,8 @@ def _transform(self, dataset):
19942001
# update the accumulator column with the result of prediction of models
19952002
aggregatedDataset = newDataset
19962003
for index, model in enumerate(self.models):
1997-
rawPredictionCol = model._call_java("getRawPredictionCol")
2004+
rawPredictionCol = self.getRawPredictionCol()
2005+
19982006
columns = origCols + [rawPredictionCol, accColName]
19992007

20002008
# add temporary column to store intermediate scores and update
@@ -2015,14 +2023,24 @@ def _transform(self, dataset):
20152023
if handlePersistence:
20162024
newDataset.unpersist()
20172025

2018-
# output the index of the classifier with highest confidence as prediction
2019-
labelUDF = udf(
2020-
lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]),
2021-
DoubleType())
2022-
2023-
# output label and label metadata as prediction
2024-
return aggregatedDataset.withColumn(
2025-
self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)
2026+
if self.getRawPredictionCol():
2027+
def func(predictions):
2028+
predArray = []
2029+
for x in predictions:
2030+
predArray.append(x)
2031+
return Vectors.dense(predArray)
2032+
2033+
rawPredictionUDF = udf(func)
2034+
aggregatedDataset = aggregatedDataset.withColumn(
2035+
self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))
2036+
2037+
if self.getPredictionCol():
2038+
# output the index of the classifier with highest confidence as prediction
2039+
labelUDF = udf(lambda predictions: float(max(enumerate(predictions),
2040+
key=operator.itemgetter(1))[0]), DoubleType())
2041+
aggregatedDataset = aggregatedDataset.withColumn(
2042+
self.getPredictionCol(), labelUDF(aggregatedDataset[accColName]))
2043+
return aggregatedDataset.drop(accColName)
20262044

20272045
@since("2.0.0")
20282046
def copy(self, extra=None):

python/pyspark/ml/tests/test_algorithms.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ def test_output_columns(self):
114114
ovr = OneVsRest(classifier=lr, parallelism=1)
115115
model = ovr.fit(df)
116116
output = model.transform(df)
117-
self.assertEqual(output.columns, ["label", "features", "prediction"])
117+
self.assertEqual(output.columns, ["label", "features", "rawPrediction", "prediction"])
118118

119119
def test_parallelism_doesnt_change_output(self):
120120
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),

0 commit comments

Comments
 (0)