Skip to content

Commit 05e3e40

Browse files
committed
update example
1 parent d3e8dbe commit 05e3e40

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

examples/src/main/python/ml/simple_text_classification_pipeline.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@
2222
from pyspark.ml.classification import LogisticRegression
2323

2424

25+
"""
26+
A simple text classification pipeline that recognizes "spark" from
27+
input text. This is to show how to create and configure a Spark ML
28+
pipeline in Python. Run with:
29+
30+
bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py
31+
"""
32+
33+
2534
if __name__ == "__main__":
2635
sc = SparkContext(appName="SimpleTextClassificationPipeline")
2736
sqlCtx = SQLContext(sc)
@@ -53,5 +62,9 @@
5362
(7L, "apache hadoop")])
5463
.map(lambda x: Row(id=x[0], text=x[1])))
5564

56-
for row in model.transform(test).collect():
65+
prediction = model.transform(test)
66+
67+
prediction.registerTempTable("prediction")
68+
selected = sqlCtx.sql("SELECT id, text, prediction from prediction")
69+
for row in selected.collect():
5770
print row

python/pyspark/ml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def getStages(self):
144144

145145
def fit(self, dataset, params={}):
146146
paramMap = self._merge_params(params)
147-
stages = paramMap(self.stages)
147+
stages = paramMap[self.stages]
148148
for stage in stages:
149149
if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
150150
raise ValueError(

0 commit comments

Comments
 (0)