-
Notifications
You must be signed in to change notification settings - Fork 28.6k
[SPARK-11629] [ML] [PySpark] [Doc] Python example code for Multilayer Perceptron Classification #9594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
[SPARK-11629] [ML] [PySpark] [Doc] Python example code for Multilayer Perceptron Classification #9594
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
...src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.ml; | ||
|
||
// $example on$ | ||
import org.apache.spark.SparkConf; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.sql.SQLContext; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; | ||
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; | ||
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
import org.apache.spark.sql.DataFrame; | ||
// $example off$ | ||
|
||
/** | ||
* An example for Multilayer Perceptron Classification. | ||
*/ | ||
public class JavaMultilayerPerceptronClassifierExample { | ||
|
||
public static void main(String[] args) { | ||
SparkConf conf = new SparkConf().setAppName("JavaMultilayerPerceptronClassifierExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(conf); | ||
SQLContext jsql = new SQLContext(jsc); | ||
|
||
// $example on$ | ||
// Load training data | ||
String path = "data/mllib/sample_multiclass_classification_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD(); | ||
DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); | ||
// Split the data into train and test | ||
DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); | ||
DataFrame train = splits[0]; | ||
DataFrame test = splits[1]; | ||
// specify layers for the neural network: | ||
// input layer of size 4 (features), two intermediate of size 5 and 4 | ||
// and output of size 3 (classes) | ||
int[] layers = new int[] {4, 5, 4, 3}; | ||
// create the trainer and set its parameters | ||
MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() | ||
.setLayers(layers) | ||
.setBlockSize(128) | ||
.setSeed(1234L) | ||
.setMaxIter(100); | ||
// train the model | ||
MultilayerPerceptronClassificationModel model = trainer.fit(train); | ||
// compute precision on the test set | ||
DataFrame result = model.transform(test); | ||
DataFrame predictionAndLabels = result.select("prediction", "label"); | ||
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() | ||
.setMetricName("precision"); | ||
System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); | ||
// $example off$ | ||
|
||
jsc.stop(); | ||
} | ||
} |
56 changes: 56 additions & 0 deletions
56
examples/src/main/python/ml/multilayer_perceptron_classification.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from __future__ import print_function | ||
|
||
from pyspark import SparkContext | ||
from pyspark.sql import SQLContext | ||
# $example on$ | ||
from pyspark.ml.classification import MultilayerPerceptronClassifier | ||
from pyspark.ml.evaluation import MulticlassClassificationEvaluator | ||
from pyspark.mllib.util import MLUtils | ||
# $example off$ | ||
|
||
if __name__ == "__main__": | ||
|
||
sc = SparkContext(appName="multilayer_perceptron_classification_example") | ||
sqlContext = SQLContext(sc) | ||
|
||
# $example on$ | ||
# Load training data | ||
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt")\ | ||
.toDF() | ||
# Split the data into train and test | ||
splits = data.randomSplit([0.6, 0.4], 1234) | ||
train = splits[0] | ||
test = splits[1] | ||
# specify layers for the neural network: | ||
# input layer of size 4 (features), two intermediate of size 5 and 4 | ||
# and output of size 3 (classes) | ||
layers = [4, 5, 4, 3] | ||
# create the trainer and set its parameters | ||
trainer = MultilayerPerceptronClassifier(maxIter=100, layers=layers, blockSize=128, seed=1234) | ||
# train the model | ||
model = trainer.fit(train) | ||
# compute precision on the test set | ||
result = model.transform(test) | ||
predictionAndLabels = result.select("prediction", "label") | ||
evaluator = MulticlassClassificationEvaluator(metricName="precision") | ||
print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) | ||
# $example off$ | ||
|
||
sc.stop() |
71 changes: 71 additions & 0 deletions
71
...s/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
// scalastyle:off println | ||
package org.apache.spark.examples.ml | ||
|
||
import org.apache.spark.{SparkContext, SparkConf} | ||
import org.apache.spark.sql.SQLContext | ||
// $example on$ | ||
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier | ||
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator | ||
import org.apache.spark.mllib.util.MLUtils | ||
// $example off$ | ||
|
||
/** | ||
* An example for Multilayer Perceptron Classification. | ||
*/ | ||
object MultilayerPerceptronClassifierExample { | ||
|
||
def main(args: Array[String]): Unit = { | ||
val conf = new SparkConf().setAppName("MultilayerPerceptronClassifierExample") | ||
val sc = new SparkContext(conf) | ||
val sqlContext = new SQLContext(sc) | ||
import sqlContext.implicits._ | ||
|
||
// $example on$ | ||
// Load training data | ||
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") | ||
.toDF() | ||
// Split the data into train and test | ||
val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) | ||
val train = splits(0) | ||
val test = splits(1) | ||
// specify layers for the neural network: | ||
// input layer of size 4 (features), two intermediate of size 5 and 4 | ||
// and output of size 3 (classes) | ||
val layers = Array[Int](4, 5, 4, 3) | ||
// create the trainer and set its parameters | ||
val trainer = new MultilayerPerceptronClassifier() | ||
.setLayers(layers) | ||
.setBlockSize(128) | ||
.setSeed(1234L) | ||
.setMaxIter(100) | ||
// train the model | ||
val model = trainer.fit(train) | ||
// compute precision on the test set | ||
val result = model.transform(test) | ||
val predictionAndLabels = result.select("prediction", "label") | ||
val evaluator = new MulticlassClassificationEvaluator() | ||
.setMetricName("precision") | ||
println("Precision:" + evaluator.evaluate(predictionAndLabels)) | ||
// $example off$ | ||
|
||
sc.stop() | ||
} | ||
} | ||
// scalastyle:off println |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not part of this PR, but we should replace
MLUtils.loadLibSVMFile
bysqlContext.read.format("libsvm").load(...)
. Could you submit another PR after this?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I will send a PR.