Skip to content

Commit 7246d35

Browse files
author
Ram Sriharsha
committed
[SPARK-7387][ml][doc] CrossValidator example code in Python
1 parent 4e5220c commit 7246d35

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from __future__ import print_function
19+
20+
from pyspark import SparkContext
21+
from pyspark.ml import Pipeline
22+
from pyspark.ml.classification import LogisticRegression
23+
from pyspark.ml.evaluation import BinaryClassificationEvaluator
24+
from pyspark.ml.feature import HashingTF, Tokenizer
25+
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
26+
from pyspark.sql import Row, SQLContext
27+
28+
"""
29+
A simple example demonstrating model selection using CrossValidator.
30+
This example also demonstrates how Pipelines are Estimators.
31+
Run with:
32+
33+
bin/spark-submit examples/src/main/python/ml/cross_validator.py
34+
"""
35+
36+
if __name__ == "__main__":
    sc = SparkContext(appName="CrossValidatorExample")
    # sqlContext is not referenced again, but creating it is what makes
    # RDD.toDF() (used below) available in this Spark version.
    sqlContext = SQLContext(sc)

    # Prepare training documents, which are labeled.
    LabeledDocument = Row("id", "text", "label")
    training = sc.parallelize([(0, "a b c d e spark", 1.0),
                               (1, "b d", 0.0),
                               (2, "spark f g h", 1.0),
                               (3, "hadoop mapreduce", 0.0)]) \
        .map(lambda x: LabeledDocument(*x)).toDF()

    # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
    lr = LogisticRegression(maxIter=10, regParam=0.001)
    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

    # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    # This will allow us to jointly choose parameters for all Pipeline stages.
    # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    # We use a ParamGridBuilder to construct a grid of parameters to search over.
    # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
    # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
    paramGrid = ParamGridBuilder() \
        .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
        .addGrid(lr.regParam, [0.1, 0.01]) \
        .build()

    crossval = CrossValidator(estimator=pipeline,
                              estimatorParamMaps=paramGrid,
                              evaluator=BinaryClassificationEvaluator(),
                              numFolds=2)

    # Run cross-validation, and choose the best set of parameters.
    cvModel = crossval.fit(training)

    # Prepare test documents, which are unlabeled.
    # Ids are plain ints (no Python 2-only ``L`` suffix) so the script also
    # parses under Python 3; this matches the ids used for the training data.
    Document = Row("id", "text")
    test = sc.parallelize([(4, "spark i j k"),
                           (5, "l m n"),
                           (6, "mapreduce spark"),
                           (7, "apache hadoop")]) \
        .map(lambda x: Document(*x)).toDF()

    # Make predictions on test documents. cvModel uses the best model found
    # during cross-validation.
    prediction = cvModel.transform(test)
    selected = prediction.select("id", "text", "probability", "prediction")
    for row in selected.collect():
        print(row)

    sc.stop()

0 commit comments

Comments
 (0)