@@ -96,6 +96,13 @@ def numClasses(self):
96
96
"""
97
97
return self ._call_java ("numClasses" )
98
98
99
+ @since ("3.0.0" )
100
+ def predictRaw (self , value ):
101
+ """
102
+ Raw prediction for each possible label.
103
+ """
104
+ return self ._call_java ("predictRaw" , value )
105
+
99
106
100
107
class _JavaProbabilisticClassifierParams (HasProbabilityCol , HasThresholds , _JavaClassifierParams ):
101
108
"""
@@ -149,6 +156,13 @@ def setThresholds(self, value):
149
156
"""
150
157
return self ._set (thresholds = value )
151
158
159
+ @since ("3.0.0" )
160
+ def predictProbability (self , value ):
161
+ """
162
+ Predict the probability of each class given the features.
163
+ """
164
+ return self ._call_java ("predictProbability" , value )
165
+
152
166
153
167
class _LinearSVCParams (_JavaClassifierParams , HasRegParam , HasMaxIter , HasFitIntercept , HasTol ,
154
168
HasStandardization , HasWeightCol , HasAggregationDepth , HasThreshold ):
@@ -211,6 +225,8 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable
211
225
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
212
226
>>> model.predict(test0.head().features)
213
227
1.0
228
+ >>> model.predictRaw(test0.head().features)
229
+ DenseVector([-1.4831, 1.4831])
214
230
>>> result = model.transform(test0).head()
215
231
>>> result.newPrediction
216
232
1.0
@@ -568,6 +584,10 @@ class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams,
568
584
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
569
585
>>> blorModel.predict(test0.head().features)
570
586
1.0
587
+ >>> blorModel.predictRaw(test0.head().features)
588
+ DenseVector([-3.54..., 3.54...])
589
+ >>> blorModel.predictProbability(test0.head().features)
590
+ DenseVector([0.028, 0.972])
571
591
>>> result = blorModel.transform(test0).head()
572
592
>>> result.prediction
573
593
1.0
@@ -1148,6 +1168,10 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifie
1148
1168
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1149
1169
>>> model.predict(test0.head().features)
1150
1170
0.0
1171
+ >>> model.predictRaw(test0.head().features)
1172
+ DenseVector([1.0, 0.0])
1173
+ >>> model.predictProbability(test0.head().features)
1174
+ DenseVector([1.0, 0.0])
1151
1175
>>> result = model.transform(test0).head()
1152
1176
>>> result.prediction
1153
1177
0.0
@@ -1379,6 +1403,10 @@ class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifie
1379
1403
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1380
1404
>>> model.predict(test0.head().features)
1381
1405
0.0
1406
+ >>> model.predictRaw(test0.head().features)
1407
+ DenseVector([2.0, 0.0])
1408
+ >>> model.predictProbability(test0.head().features)
1409
+ DenseVector([1.0, 0.0])
1382
1410
>>> result = model.transform(test0).head()
1383
1411
>>> result.prediction
1384
1412
0.0
@@ -1640,6 +1668,10 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
1640
1668
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
1641
1669
>>> model.predict(test0.head().features)
1642
1670
0.0
1671
+ >>> model.predictRaw(test0.head().features)
1672
+ DenseVector([1.1697, -1.1697])
1673
+ >>> model.predictProbability(test0.head().features)
1674
+ DenseVector([0.9121, 0.0879])
1643
1675
>>> result = model.transform(test0).head()
1644
1676
>>> result.prediction
1645
1677
0.0
@@ -1959,6 +1991,10 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
1959
1991
>>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
1960
1992
>>> model.predict(test0.head().features)
1961
1993
1.0
1994
+ >>> model.predictRaw(test0.head().features)
1995
+ DenseVector([-1.72..., -0.99...])
1996
+ >>> model.predictProbability(test0.head().features)
1997
+ DenseVector([0.32..., 0.67...])
1962
1998
>>> result = model.transform(test0).head()
1963
1999
>>> result.prediction
1964
2000
1.0
@@ -2174,6 +2210,10 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
2174
2210
... (Vectors.dense([0.0, 0.0]),)], ["features"])
2175
2211
>>> model.predict(testDF.head().features)
2176
2212
1.0
2213
+ >>> model.predictRaw(testDF.head().features)
2214
+ DenseVector([-16.208, 16.344])
2215
+ >>> model.predictProbability(testDF.head().features)
2216
+ DenseVector([0.0, 1.0])
2177
2217
>>> model.transform(testDF).select("features", "prediction").show()
2178
2218
+---------+----------+
2179
2219
| features|prediction|
@@ -2791,6 +2831,10 @@ class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, Ja
2791
2831
... (Vectors.dense(0.5),),
2792
2832
... (Vectors.dense(1.0),),
2793
2833
... (Vectors.dense(2.0),)], ["features"])
2834
+ >>> model.predictRaw(test0.head().features)
2835
+ DenseVector([22.13..., -22.13...])
2836
+ >>> model.predictProbability(test0.head().features)
2837
+ DenseVector([1.0, 0.0])
2794
2838
>>> model.transform(test0).select("features", "probability").show(10, False)
2795
2839
+--------+------------------------------------------+
2796
2840
|features|probability |
0 commit comments