16 | 16 | #
17 | 17 |
18 | 18 | """
19 |    | -An example of how to use DataFrame as a dataset for ML. Run with::
20 |    | -    bin/spark-submit examples/src/main/python/mllib/dataset_example.py
   | 19 | +An example of how to use DataFrame for ML. Run with::
   | 20 | +    bin/spark-submit examples/src/main/python/ml/dataframe_example.py <input>
21 | 21 | """
22 | 22 | from __future__ import print_function
23 | 23 |
28 | 28 |
29 | 29 | from pyspark import SparkContext
30 | 30 | from pyspark.sql import SQLContext
31 |    | -from pyspark.mllib.util import MLUtils
32 | 31 | from pyspark.mllib.stat import Statistics
33 | 32 |
34 |    | -
35 |    | -def summarize(dataset):
36 |    | -    print("schema: %s" % dataset.schema().json())
37 |    | -    labels = dataset.map(lambda r: r.label)
38 |    | -    print("label average: %f" % labels.mean())
39 |    | -    features = dataset.map(lambda r: r.features)
40 |    | -    summary = Statistics.colStats(features)
41 |    | -    print("features average: %r" % summary.mean())
42 |    | -
43 | 33 | if __name__ == "__main__":
44 | 34 |     if len(sys.argv) > 2:
45 |    | -        print("Usage: dataset_example.py <libsvm file>", file=sys.stderr)
   | 35 | +        print("Usage: dataframe_example.py <libsvm file>", file=sys.stderr)
46 | 36 |         exit(-1)
47 |    | -    sc = SparkContext(appName="DatasetExample")
   | 37 | +    sc = SparkContext(appName="DataFrameExample")
48 | 38 |     sqlContext = SQLContext(sc)
49 | 39 |     if len(sys.argv) == 2:
50 | 40 |         input = sys.argv[1]
51 | 41 |     else:
52 | 42 |         input = "data/mllib/sample_libsvm_data.txt"
53 |    | -    points = MLUtils.loadLibSVMFile(sc, input)
54 |    | -    dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache()
55 |    | -    summarize(dataset0)
   | 43 | +
   | 44 | +    # Load input data
   | 45 | +    print("Loading LIBSVM file with UDT from " + input + ".")
   | 46 | +    df = sqlContext.read.format("libsvm").load(input).cache()
   | 47 | +    print("Schema from LIBSVM:")
   | 48 | +    df.printSchema()
   | 49 | +    print("Loaded training data as a DataFrame with " +
   | 50 | +          str(df.count()) + " records.")
   | 51 | +
   | 52 | +    # Show statistical summary of labels.
   | 53 | +    labelSummary = df.describe("label")
   | 54 | +    labelSummary.show()
   | 55 | +
   | 56 | +    # Convert features column to an RDD of vectors.
   | 57 | +    features = df.select("features").map(lambda r: r.features)
   | 58 | +    summary = Statistics.colStats(features)
   | 59 | +    print("Selected features column with average values:\n" +
   | 60 | +          str(summary.mean()))
   | 61 | +
   | 62 | +    # Save the records in a parquet file.
56 | 63 |     tempdir = tempfile.NamedTemporaryFile(delete=False).name
57 | 64 |     os.unlink(tempdir)
58 |    | -    print("Save dataset as a Parquet file to %s." % tempdir)
59 |    | -    dataset0.saveAsParquetFile(tempdir)
60 |    | -    print("Load it back and summarize it again.")
61 |    | -    dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache()
62 |    | -    summarize(dataset1)
   | 65 | +    print("Saving to " + tempdir + " as Parquet file.")
   | 66 | +    df.write.parquet(tempdir)
   | 67 | +
   | 68 | +    # Load the records back.
   | 69 | +    print("Loading Parquet file with UDT from " + tempdir)
   | 70 | +    newDF = sqlContext.read.parquet(tempdir)
   | 71 | +    print("Schema from Parquet:")
   | 72 | +    newDF.printSchema()
63 | 73 |     shutil.rmtree(tempdir)
   | 74 | +
   | 75 | +    sc.stop()
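The renamed script is submitted as shown in its docstring (bin/spark-submit examples/src/main/python/ml/dataframe_example.py <input>). For quick experimentation, the core DataFrame flow the patch introduces can also be reproduced in isolation; the sketch below mirrors the added code and assumes the Spark 1.6-era API used in the diff (the libsvm data source plus DataFrame.map; on later Spark versions the equivalent is typically df.select("features").rdd.map(...)). The app name and input path here are illustrative only.

# Minimal standalone sketch of the flow added by this patch
# (assumes a Spark 1.6-era SQLContext, as used in the diff above).
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.stat import Statistics

sc = SparkContext(appName="DataFrameExampleSketch")
sqlContext = SQLContext(sc)

# The libsvm data source yields a DataFrame with "label" and "features" columns.
df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt").cache()
df.printSchema()
df.describe("label").show()  # per-column summary statistics for the label

# Column-wise statistics over the feature vectors via MLlib's Statistics helper.
features = df.select("features").map(lambda r: r.features)
print(Statistics.colStats(features).mean())

sc.stop()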