|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +from __future__ import print_function |
| 19 | + |
| 20 | +from pyspark import SparkContext |
| 21 | +from pyspark.ml import Pipeline |
| 22 | +from pyspark.ml.classification import LogisticRegression |
| 23 | +from pyspark.ml.evaluation import BinaryClassificationEvaluator |
| 24 | +from pyspark.ml.feature import HashingTF, Tokenizer |
| 25 | +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder |
| 26 | +from pyspark.sql import Row, SQLContext |
| 27 | + |
| 28 | +""" |
| 29 | +A simple example demonstrating model selection using CrossValidator. |
| 30 | +This example also demonstrates how Pipelines are Estimators. |
| 31 | +Run with: |
| 32 | +
|
| 33 | + bin/spark-submit examples/src/main/python/ml/cross_validator.py |
| 34 | +""" |
| 35 | + |
| 36 | +if __name__ == "__main__": |
| 37 | + sc = SparkContext(appName="CrossValidatorExample") |
| 38 | + sqlContext = SQLContext(sc) |
| 39 | + |
| 40 | + # Prepare training documents, which are labeled. |
| 41 | + LabeledDocument = Row("id", "text", "label") |
| 42 | + training = sc.parallelize([(0, "a b c d e spark", 1.0), |
| 43 | + (1, "b d", 0.0), |
| 44 | + (2, "spark f g h", 1.0), |
| 45 | + (3, "hadoop mapreduce", 0.0)]) \ |
| 46 | + .map(lambda x: LabeledDocument(*x)).toDF() |
| 47 | + |
| 48 | + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. |
| 49 | + tokenizer = Tokenizer(inputCol="text", outputCol="words") |
| 50 | + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") |
| 51 | + lr = LogisticRegression(maxIter=10, regParam=0.001) |
| 52 | + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) |
| 53 | + |
| 54 | + # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. |
| 55 | + # This will allow us to jointly choose parameters for all Pipeline stages. |
| 56 | + # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. |
| 57 | + # We use a ParamGridBuilder to construct a grid of parameters to search over. |
| 58 | + # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, |
| 59 | + # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. |
| 60 | + paramGrid = ParamGridBuilder() \ |
| 61 | + .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \ |
| 62 | + .addGrid(lr.regParam, [0.1, 0.01]) \ |
| 63 | + .build() |
| 64 | + |
| 65 | + crossval = CrossValidator(estimator=pipeline, |
| 66 | + estimatorParamMaps=paramGrid, |
| 67 | + evaluator=BinaryClassificationEvaluator(), |
| 68 | + numFolds=2) |
| 69 | + |
| 70 | + # Run cross-validation, and choose the best set of parameters. |
| 71 | + cvModel = crossval.fit(training) |
| 72 | + |
| 73 | + # Prepare test documents, which are unlabeled. |
| 74 | + Document = Row("id", "text") |
| 75 | + test = sc.parallelize([(4L, "spark i j k"), |
| 76 | + (5L, "l m n"), |
| 77 | + (6L, "mapreduce spark"), |
| 78 | + (7L, "apache hadoop")]) \ |
| 79 | + .map(lambda x: Document(*x)).toDF() |
| 80 | + |
| 81 | + # Make predictions on test documents. cvModel uses the best model found (lrModel). |
| 82 | + prediction = cvModel.transform(test) |
| 83 | + selected = prediction.select("id", "text", "probability", "prediction") |
| 84 | + for row in selected.collect(): |
| 85 | + print(row) |
| 86 | + |
| 87 | + sc.stop() |
0 commit comments