Skip to content

Commit 3f00bb3

Browse files
MechCodermengxr
authored andcommitted
[SPARK-6083] [MLLib] [DOC] Make Python API example consistent in NaiveBayes
Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #4834 from MechCoder/spark-6083 and squashes the following commits: 1cdd7b5 [MechCoder] Add parse function 65bbbe9 [MechCoder] [SPARK-6083] Make Python API example consistent in NaiveBayes
1 parent aedbbaa commit 3f00bb3

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

docs/mllib-naive-bayes.md

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,28 @@ used for evaluation and prediction.
115115

116116
Note that the Python API does not yet support model save/load but will in the future.
117117

118-
<!-- TODO: Make Python's example consistent with Scala's and Java's. -->
119118
{% highlight python %}
120-
from pyspark.mllib.regression import LabeledPoint
121119
from pyspark.mllib.classification import NaiveBayes
120+
from pyspark.mllib.linalg import Vectors
121+
from pyspark.mllib.regression import LabeledPoint
122+
123+
def parseLine(line):
124+
parts = line.split(',')
125+
label = float(parts[0])
126+
features = Vectors.dense([float(x) for x in parts[1].split(' ')])
127+
return LabeledPoint(label, features)
128+
129+
data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine)
122130

123-
# an RDD of LabeledPoint
124-
data = sc.parallelize([
125-
LabeledPoint(0.0, [0.0, 0.0])
126-
... # more labeled points
127-
])
131+
# Split data aproximately into training (60%) and test (40%)
132+
training, test = data.randomSplit([0.6, 0.4], seed = 0)
128133

129134
# Train a naive Bayes model.
130-
model = NaiveBayes.train(data, 1.0)
135+
model = NaiveBayes.train(training, 1.0)
131136

132-
# Make prediction.
133-
prediction = model.predict([0.0, 0.0])
137+
# Make prediction and test accuracy.
138+
predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label))
139+
accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count()
134140
{% endhighlight %}
135141

136142
</div>

0 commit comments

Comments
 (0)