from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
@@ -1717,7 +1718,8 @@ def weights(self):
        return self._call_java("weights")


-class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol,
+                      HasRawPredictionCol):
    """
    Parameters for OneVsRest and OneVsRestModel.
    """
@@ -1758,6 +1760,8 @@ class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, Java
    >>> df = spark.read.format("libsvm").load(data_path)
    >>> lr = LogisticRegression(regParam=0.01)
    >>> ovr = OneVsRest(classifier=lr)
+    >>> ovr.getRawPredictionCol()
+    'rawPrediction'
    >>> model = ovr.fit(df)
    >>> model.models[0].coefficients
    DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
@@ -1781,16 +1785,18 @@ class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, Java
    >>> model2 = OneVsRestModel.load(model_path)
    >>> model2.transform(test0).head().prediction
    0.0
+    >>> model.transform(test2).columns
+    ['features', 'rawPrediction', 'prediction']

    .. versionadded:: 2.0.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 classifier=None, weightCol=None, parallelism=1):
+                 rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                 classifier=None, weightCol=None, parallelism=1):
+                 rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
        """
        super(OneVsRest, self).__init__()
        self._setDefault(parallelism=1)
@@ -1800,10 +1806,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
    @keyword_only
    @since("2.0.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                  classifier=None, weightCol=None, parallelism=1):
+                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                  classifier=None, weightCol=None, parallelism=1):
+                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
        Sets params for OneVsRest.
        """
        kwargs = self._input_kwargs
@@ -1814,8 +1820,6 @@ def _fit(self, dataset):
        featuresCol = self.getFeaturesCol()
        predictionCol = self.getPredictionCol()
        classifier = self.getClassifier()
-        assert isinstance(classifier, HasRawPredictionCol),\
-            "Classifier %s doesn't extend from HasRawPredictionCol." % type(classifier)


        numClasses = int(dataset.agg({labelCol: "max"}).head()["max(" + labelCol + ")"]) + 1
@@ -1884,10 +1888,12 @@ def _from_java(cls, java_stage):
        featuresCol = java_stage.getFeaturesCol()
        labelCol = java_stage.getLabelCol()
        predictionCol = java_stage.getPredictionCol()
+        rawPredictionCol = java_stage.getRawPredictionCol()
        classifier = JavaParams._from_java(java_stage.getClassifier())
        parallelism = java_stage.getParallelism()
        py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
-                       classifier=classifier, parallelism=parallelism)
+                       rawPredictionCol=rawPredictionCol, classifier=classifier,
+                       parallelism=parallelism)
        py_stage._resetUid(java_stage.uid())
        return py_stage

@@ -1904,6 +1910,7 @@ def _to_java(self):
        _java_obj.setFeaturesCol(self.getFeaturesCol())
        _java_obj.setLabelCol(self.getLabelCol())
        _java_obj.setPredictionCol(self.getPredictionCol())
+        _java_obj.setRawPredictionCol(self.getRawPredictionCol())
        return _java_obj

    def _make_java_param_pair(self, param, value):
@@ -1994,7 +2001,8 @@ def _transform(self, dataset):
        # update the accumulator column with the result of prediction of models
        aggregatedDataset = newDataset
        for index, model in enumerate(self.models):
-            rawPredictionCol = model._call_java("getRawPredictionCol")
+            rawPredictionCol = self.getRawPredictionCol()
+
            columns = origCols + [rawPredictionCol, accColName]

            # add temporary column to store intermediate scores and update
@@ -2015,14 +2023,24 @@ def _transform(self, dataset):
        if handlePersistence:
            newDataset.unpersist()

-        # output the index of the classifier with highest confidence as prediction
-        labelUDF = udf(
-            lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]),
-            DoubleType())
-
-        # output label and label metadata as prediction
-        return aggregatedDataset.withColumn(
-            self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)
+        if self.getRawPredictionCol():
+            def func(predictions):
+                predArray = []
+                for x in predictions:
+                    predArray.append(x)
+                return Vectors.dense(predArray)
+
+            rawPredictionUDF = udf(func)
+            aggregatedDataset = aggregatedDataset.withColumn(
+                self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))
+
+        if self.getPredictionCol():
+            # output the index of the classifier with highest confidence as prediction
+            labelUDF = udf(lambda predictions: float(max(enumerate(predictions),
+                           key=operator.itemgetter(1))[0]), DoubleType())
+            aggregatedDataset = aggregatedDataset.withColumn(
+                self.getPredictionCol(), labelUDF(aggregatedDataset[accColName]))
+        return aggregatedDataset.drop(accColName)

    @since("2.0.0")
    def copy(self, extra=None):
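For reference, a minimal usage sketch of what this change enables (not part of the patch; the toy three-class DataFrame and the "ovrScores" column name below are illustrative assumptions): after the change, OneVsRest exposes rawPredictionCol and its transform output carries the aggregated per-class scores alongside the prediction column.

# Illustrative sketch only, assuming the patch above is applied.
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Tiny three-class dataset with dense feature vectors (made up for the sketch).
df = spark.createDataFrame(
    [(0.0, Vectors.dense(0.0, 1.0)),
     (1.0, Vectors.dense(1.0, 0.0)),
     (2.0, Vectors.dense(1.0, 1.0))],
    ["label", "features"])

# rawPredictionCol defaults to "rawPrediction"; overriding it shows that the
# per-class scores land in whatever output column is requested.
ovr = OneVsRest(classifier=LogisticRegression(regParam=0.01),
                rawPredictionCol="ovrScores")
model = ovr.fit(df)
model.transform(df).select("ovrScores", "prediction").show(truncate=False)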