Skip to content

Commit 8ae338f

Browse files
committed
[SPARK-30358][ML][Pyspark][FOLLOWUP] ML expose predictRaw and predictProbability on Python side
1 parent ce7a49f commit 8ae338f

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

python/pyspark/ml/classification.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,13 @@ def numClasses(self):
9696
"""
9797
return self._call_java("numClasses")
9898

99+
@since("3.0.0")
100+
def predictRaw(self, value):
101+
"""
102+
Raw prediction for each possible label.
103+
"""
104+
return self._call_java("predictRaw", value)
105+
99106

100107
class _JavaProbabilisticClassifierParams(HasProbabilityCol, HasThresholds, _JavaClassifierParams):
101108
"""
@@ -149,6 +156,13 @@ def setThresholds(self, value):
149156
"""
150157
return self._set(thresholds=value)
151158

159+
@since("3.0.0")
160+
def predictProbability(self, value):
161+
"""
162+
Predict the probability of each class given the features.
163+
"""
164+
return self._call_java("predictProbability", value)
165+
152166

153167
class _LinearSVCParams(_JavaClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol,
154168
HasStandardization, HasWeightCol, HasAggregationDepth, HasThreshold):
@@ -211,6 +225,8 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable
211225
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
212226
>>> model.predict(test0.head().features)
213227
1.0
228+
>>> model.predictRaw(test0.head().features)
229+
DenseVector([-1.4831, 1.4831])
214230
>>> result = model.transform(test0).head()
215231
>>> result.newPrediction
216232
1.0
@@ -568,6 +584,10 @@ class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams,
568584
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
569585
>>> blorModel.predict(test0.head().features)
570586
1.0
587+
>>> blorModel.predictRaw(test0.head().features)
588+
DenseVector([-3.54..., 3.54...])
589+
>>> blorModel.predictProbability(test0.head().features)
590+
DenseVector([0.028, 0.972])
571591
>>> result = blorModel.transform(test0).head()
572592
>>> result.prediction
573593
1.0
@@ -1148,6 +1168,10 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifie
11481168
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
11491169
>>> model.predict(test0.head().features)
11501170
0.0
1171+
>>> model.predictRaw(test0.head().features)
1172+
DenseVector([1.0, 0.0])
1173+
>>> model.predictProbability(test0.head().features)
1174+
DenseVector([1.0, 0.0])
11511175
>>> result = model.transform(test0).head()
11521176
>>> result.prediction
11531177
0.0
@@ -1379,6 +1403,10 @@ class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifie
13791403
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
13801404
>>> model.predict(test0.head().features)
13811405
0.0
1406+
>>> model.predictRaw(test0.head().features)
1407+
DenseVector([2.0, 0.0])
1408+
>>> model.predictProbability(test0.head().features)
1409+
DenseVector([1.0, 0.0])
13821410
>>> result = model.transform(test0).head()
13831411
>>> result.prediction
13841412
0.0
@@ -1640,6 +1668,10 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
16401668
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
16411669
>>> model.predict(test0.head().features)
16421670
0.0
1671+
>>> model.predictRaw(test0.head().features)
1672+
DenseVector([1.1697, -1.1697])
1673+
>>> model.predictProbability(test0.head().features)
1674+
DenseVector([0.9121, 0.0879])
16431675
>>> result = model.transform(test0).head()
16441676
>>> result.prediction
16451677
0.0
@@ -1959,6 +1991,10 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
19591991
>>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
19601992
>>> model.predict(test0.head().features)
19611993
1.0
1994+
>>> model.predictRaw(test0.head().features)
1995+
DenseVector([-1.72..., -0.99...])
1996+
>>> model.predictProbability(test0.head().features)
1997+
DenseVector([0.32..., 0.67...])
19621998
>>> result = model.transform(test0).head()
19631999
>>> result.prediction
19642000
1.0
@@ -2174,6 +2210,10 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
21742210
... (Vectors.dense([0.0, 0.0]),)], ["features"])
21752211
>>> model.predict(testDF.head().features)
21762212
1.0
2213+
>>> model.predictRaw(testDF.head().features)
2214+
DenseVector([-16.208, 16.344])
2215+
>>> model.predictProbability(testDF.head().features)
2216+
DenseVector([0.0, 1.0])
21772217
>>> model.transform(testDF).select("features", "prediction").show()
21782218
+---------+----------+
21792219
| features|prediction|
@@ -2791,6 +2831,10 @@ class FMClassifier(JavaProbabilisticClassifier, _FactorizationMachinesParams, Ja
27912831
... (Vectors.dense(0.5),),
27922832
... (Vectors.dense(1.0),),
27932833
... (Vectors.dense(2.0),)], ["features"])
2834+
>>> model.predictRaw(test0.head().features)
2835+
DenseVector([22.13..., -22.13...])
2836+
>>> model.predictProbability(test0.head().features)
2837+
DenseVector([1.0, 0.0])
27942838
>>> model.transform(test0).select("features", "probability").show(10, False)
27952839
+--------+------------------------------------------+
27962840
|features|probability |

0 commit comments

Comments (0)