|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +""" |
| 19 | +This example uses random hyperparameters to perform model selection. |
| 20 | +Run with: |
| 21 | +
|
| 22 | + bin/spark-submit examples/src/main/python/ml/model_selection_random_hyperparameters_example.py |
| 23 | +""" |
| 24 | +# $example on$ |
| 25 | +from pyspark.ml.evaluation import RegressionEvaluator |
| 26 | +from pyspark.ml.regression import LinearRegression |
| 27 | +from pyspark.ml.tuning import ParamRandomBuilder, CrossValidator |
| 28 | +# $example off$ |
| 29 | +from pyspark.sql import SparkSession |
| 30 | + |
| 31 | +if __name__ == "__main__": |
| 32 | + spark = SparkSession \ |
| 33 | + .builder \ |
| 34 | + .appName("TrainValidationSplit") \ |
| 35 | + .getOrCreate() |
| 36 | + |
| 37 | + # $example on$ |
| 38 | + data = spark.read.format("libsvm") \ |
| 39 | + .load("data/mllib/sample_linear_regression_data.txt") |
| 40 | + |
| 41 | + lr = LinearRegression(maxIter=10) |
| 42 | + |
| 43 | + # We sample the regularization parameter logarithmically over the range [0.01, 1.0]. |
| 44 | + # This means that values around 0.01, 0.1 and 1.0 are roughly equally likely. |
| 45 | + # Note that both parameters must be greater than zero as otherwise we'll get an infinity. |
| 46 | + # We sample the the ElasticNet mixing parameter uniformly over the range [0, 1] |
| 47 | + # Note that in real life, you'd choose more than the 5 samples we see below. |
| 48 | + hyperparameters = ParamRandomBuilder() \ |
| 49 | + .addLog10Random(lr.regParam, 0.01, 1.0, 5) \ |
| 50 | + .addRandom(lr.elasticNetParam, 0.0, 1.0, 5) \ |
| 51 | + .addGrid(lr.fitIntercept, [False, True]) \ |
| 52 | + .build() |
| 53 | + |
| 54 | + cv = CrossValidator(estimator=lr, |
| 55 | + estimatorParamMaps=hyperparameters, |
| 56 | + evaluator=RegressionEvaluator(), |
| 57 | + numFolds=2) |
| 58 | + |
| 59 | + model = cv.fit(data) |
| 60 | + bestModel = model.bestModel |
| 61 | + print("Optimal model has regParam = {}, elasticNetParam = {}, fitIntercept = {}" |
| 62 | + .format(bestModel.getRegParam(), bestModel.getElasticNetParam(), |
| 63 | + bestModel.getFitIntercept())) |
| 64 | + |
| 65 | + # $example off$ |
| 66 | + spark.stop() |
0 commit comments