Skip to content

Commit 0116dee

Browse files
committed
[SPARK-2433][MLLIB] fix NaiveBayesModel.predict
This is the same as apache#463 , which I forgot to merge into branch-0.9. Author: Xiangrui Meng <meng@databricks.com> Closes apache#1453 from mengxr/nb-transpose-0.9 and squashes the following commits: bc53ce8 [Xiangrui Meng] fix NaiveBayes
1 parent 8e5604b commit 0116dee

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/pyspark/mllib/classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class NaiveBayesModel(object):
8484
- pi: vector of logs of class priors (dimension C)
8585
- theta: matrix of logs of class conditional probabilities (CxD)
8686
87-
>>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3)
87+
>>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 1.0]).reshape(3,3)
8888
>>> model = NaiveBayes.train(sc.parallelize(data))
8989
>>> model.predict(array([0.0, 1.0]))
9090
0
@@ -98,7 +98,7 @@ def __init__(self, pi, theta):
9898

9999
def predict(self, x):
100100
"""Return the most likely class for a data vector x"""
101-
return numpy.argmax(self.pi + dot(x, self.theta))
101+
return numpy.argmax(self.pi + dot(x, self.theta.transpose()))
102102

103103
class NaiveBayes(object):
104104
@classmethod

0 commit comments

Comments
 (0)