
Commit 4a1bcb2

yanboliang authored and mengxr committed
[SPARK-11723][ML][DOC] Use LibSVM data source rather than MLUtils.loadLibSVMFile to load DataFrame
Use the LibSVM data source rather than MLUtils.loadLibSVMFile to load DataFrames. This includes:

* Use the libSVM data source in all example code under examples/ml, and remove the now-unused imports.
* Use the libSVM data source in the user guides under ml-*** that were omitted by #8697.
* Fix a bug: on the Java side the call must be ```sqlContext.read().format("libsvm").load(path)```, but the API doc and user guides incorrectly wrote ```sqlContext.read.format("libsvm").load(path)```.
* Code cleanup.

mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9690 from yanboliang/spark-11723.

(cherry picked from commit 99693fe)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
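
For context on the bug fix above: `read` is a parameterless method on `SQLContext`, so Scala code may drop the parentheses (`sqlContext.read`) while Java must spell out the call (`sqlContext.read()`). A minimal, self-contained sketch of the corrected Java pattern against the Spark 1.6-era API (the app name and the `show(5)` check are illustrative, not part of the patch):

```java
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class LibSVMLoadSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("LibSVMLoadSketch");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext sqlContext = new SQLContext(jsc);

    // Java: read() is a method call; `sqlContext.read.format(...)` does not compile.
    // The Scala equivalent is `sqlContext.read.format("libsvm").load(...)`,
    // because Scala allows omitting the empty parameter list.
    DataFrame data = sqlContext.read().format("libsvm")
      .load("data/mllib/sample_libsvm_data.txt");

    data.show(5);
    jsc.stop();
  }
}
```
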
1 parent 98c614d commit 4a1bcb2

26 files changed: +79 −130 lines

docs/ml-ensembles.md

Lines changed: 5 additions & 5 deletions

@@ -195,7 +195,7 @@ import org.apache.spark.ml.feature.*;
 import org.apache.spark.sql.DataFrame;
 
 // Load and parse the data file, converting it to a DataFrame.
-DataFrame data = sqlContext.read.format("libsvm")
+DataFrame data = sqlContext.read().format("libsvm")
   .load("data/mllib/sample_libsvm_data.txt");
 
 // Index labels, adding metadata to the label column.
@@ -384,7 +384,7 @@ import org.apache.spark.ml.regression.RandomForestRegressor;
 import org.apache.spark.sql.DataFrame;
 
 // Load and parse the data file, converting it to a DataFrame.
-DataFrame data = sqlContext.read.format("libsvm")
+DataFrame data = sqlContext.read().format("libsvm")
   .load("data/mllib/sample_libsvm_data.txt");
 
 // Automatically identify categorical features, and index them.
@@ -640,7 +640,7 @@ import org.apache.spark.ml.feature.*;
 import org.apache.spark.sql.DataFrame;
 
 // Load and parse the data file, converting it to a DataFrame.
-DataFrame data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
 
 // Index labels, adding metadata to the label column.
 // Fit on whole dataset to include all labels in index.
@@ -830,7 +830,7 @@ import org.apache.spark.ml.regression.GBTRegressor;
 import org.apache.spark.sql.DataFrame;
 
 // Load and parse the data file, converting it to a DataFrame.
-DataFrame data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
 
 // Automatically identify categorical features, and index them.
 // Set maxCategories so features with > 4 distinct values are treated as continuous.
@@ -1000,7 +1000,7 @@ SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
 JavaSparkContext jsc = new JavaSparkContext(conf);
 SQLContext jsql = new SQLContext(jsc);
 
-DataFrame dataFrame = sqlContext.read.format("libsvm")
+DataFrame dataFrame = sqlContext.read().format("libsvm")
   .load("data/mllib/sample_multiclass_classification_data.txt");
 
 DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345);
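
For orientation, the `libsvm` data source behind all of these hunks produces a DataFrame with two columns, `label` (double) and `features` (vector), which is why the indexers in these examples can consume it directly. A minimal sketch, assuming an existing `SQLContext` as in the guide snippets (class and method names are illustrative):

```java
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class LibSVMSchemaSketch {
  // Assumes a live SQLContext, as the guide snippets do.
  public static void inspect(SQLContext sqlContext) {
    DataFrame data = sqlContext.read().format("libsvm")
      .load("data/mllib/sample_libsvm_data.txt");

    // Expected columns: label (double) and features (vector).
    data.printSchema();
  }
}
```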

docs/ml-features.md

Lines changed: 4 additions & 4 deletions

@@ -1109,7 +1109,7 @@ import org.apache.spark.ml.feature.VectorIndexer;
 import org.apache.spark.ml.feature.VectorIndexerModel;
 import org.apache.spark.sql.DataFrame;
 
-DataFrame data = sqlContext.read.format("libsvm")
+DataFrame data = sqlContext.read().format("libsvm")
   .load("data/mllib/sample_libsvm_data.txt");
 VectorIndexer indexer = new VectorIndexer()
   .setInputCol("features")
@@ -1187,7 +1187,7 @@ for more details on the API.
 import org.apache.spark.ml.feature.Normalizer;
 import org.apache.spark.sql.DataFrame;
 
-DataFrame dataFrame = sqlContext.read.format("libsvm")
+DataFrame dataFrame = sqlContext.read().format("libsvm")
   .load("data/mllib/sample_libsvm_data.txt");
 
 // Normalize each Vector using $L^1$ norm.
@@ -1273,7 +1273,7 @@ import org.apache.spark.ml.feature.StandardScaler;
 import org.apache.spark.ml.feature.StandardScalerModel;
 import org.apache.spark.sql.DataFrame;
 
-DataFrame dataFrame = sqlContext.read.format("libsvm")
+DataFrame dataFrame = sqlContext.read().format("libsvm")
   .load("data/mllib/sample_libsvm_data.txt");
 StandardScaler scaler = new StandardScaler()
   .setInputCol("features")
@@ -1366,7 +1366,7 @@ import org.apache.spark.ml.feature.MinMaxScaler;
 import org.apache.spark.ml.feature.MinMaxScalerModel;
 import org.apache.spark.sql.DataFrame;
 
-DataFrame dataFrame = sqlContext.read.format("libsvm")
+DataFrame dataFrame = sqlContext.read().format("libsvm")
   .load("data/mllib/sample_libsvm_data.txt");
 MinMaxScaler scaler = new MinMaxScaler()
   .setInputCol("features")

docs/ml-guide.md

Lines changed: 2 additions & 8 deletions

@@ -867,10 +867,9 @@ The `ParamMap` which produces the best evaluation metric is selected as the best
 import org.apache.spark.ml.evaluation.RegressionEvaluator
 import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
-import org.apache.spark.mllib.util.MLUtils
 
 // Prepare training and test data.
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
 val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
 
 val lr = new LinearRegression()
@@ -911,14 +910,9 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator;
 import org.apache.spark.ml.param.ParamMap;
 import org.apache.spark.ml.regression.LinearRegression;
 import org.apache.spark.ml.tuning.*;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
 
-DataFrame data = sqlContext.createDataFrame(
-  MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
-  LabeledPoint.class);
+DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
 
 // Prepare training and test data.
 DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);
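
The Java hunk above is representative of the whole patch: the detour through `MLUtils.loadLibSVMFile`, an `RDD<LabeledPoint>`, and `createDataFrame` collapses into a single reader call. A side-by-side sketch under the same Spark 1.6-era API (class and helper names are hypothetical):

```java
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class BeforeAfterSketch {
  static final String PATH = "data/mllib/sample_libsvm_data.txt";

  // Before: load an RDD<LabeledPoint>, then reflect it into a DataFrame.
  static DataFrame loadOld(JavaSparkContext jsc, SQLContext jsql) {
    RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), PATH);
    return jsql.createDataFrame(rdd, LabeledPoint.class);
  }

  // After: a single call through the LibSVM data source.
  static DataFrame loadNew(SQLContext jsql) {
    return jsql.read().format("libsvm").load(PATH);
  }
}
```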

docs/ml-linear-methods.md

Lines changed: 2 additions & 2 deletions

@@ -95,7 +95,7 @@ public class LogisticRegressionWithElasticNetExample {
 String path = "data/mllib/sample_libsvm_data.txt";
 
 // Load training data
-DataFrame training = sqlContext.read.format("libsvm").load(path);
+DataFrame training = sqlContext.read().format("libsvm").load(path);
 
 LogisticRegression lr = new LogisticRegression()
   .setMaxIter(10)
@@ -292,7 +292,7 @@ public class LinearRegressionWithElasticNetExample {
 String path = "data/mllib/sample_libsvm_data.txt";
 
 // Load training data
-DataFrame training = sqlContext.read.format("libsvm").load(path);
+DataFrame training = sqlContext.read().format("libsvm").load(path);
 
 LinearRegression lr = new LinearRegression()
   .setMaxIter(10)

examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java

Lines changed: 2 additions & 6 deletions

@@ -26,9 +26,6 @@
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
 import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
 import org.apache.spark.ml.feature.*;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
 // $example off$
@@ -40,9 +37,8 @@ public static void main(String[] args) {
 SQLContext sqlContext = new SQLContext(jsc);
 
 // $example on$
-// Load and parse the data file, converting it to a DataFrame.
-RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt");
-DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class);
+// Load the data stored in LIBSVM format as a DataFrame.
+DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
 
 // Index labels, adding metadata to the label column.
 // Fit on whole dataset to include all labels in index.

examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java

Lines changed: 3 additions & 6 deletions

@@ -27,9 +27,6 @@
 import org.apache.spark.ml.feature.VectorIndexerModel;
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
 import org.apache.spark.ml.regression.DecisionTreeRegressor;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
 // $example off$
@@ -40,9 +37,9 @@ public static void main(String[] args) {
 JavaSparkContext jsc = new JavaSparkContext(conf);
 SQLContext sqlContext = new SQLContext(jsc);
 // $example on$
-// Load and parse the data file, converting it to a DataFrame.
-RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt");
-DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class);
+// Load the data stored in LIBSVM format as a DataFrame.
+DataFrame data = sqlContext.read().format("libsvm")
+  .load("data/mllib/sample_libsvm_data.txt");
 
 // Automatically identify categorical features, and index them.
 // Set maxCategories so features with > 4 distinct values are treated as continuous.

examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java

Lines changed: 1 addition & 5 deletions

@@ -21,12 +21,9 @@
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.SQLContext;
-import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
 import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
 import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.sql.DataFrame;
 // $example off$
 
@@ -43,8 +40,7 @@ public static void main(String[] args) {
 // $example on$
 // Load training data
 String path = "data/mllib/sample_multiclass_classification_data.txt";
-JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();
-DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class);
+DataFrame dataFrame = jsql.read().format("libsvm").load(path);
 // Split the data into train and test
 DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
 DataFrame train = splits[0];

examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java

Lines changed: 10 additions & 13 deletions

@@ -27,9 +27,7 @@
 import org.apache.spark.ml.util.MetadataUtils;
 import org.apache.spark.mllib.evaluation.MulticlassMetrics;
 import org.apache.spark.mllib.linalg.Matrix;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.types.StructField;
@@ -80,31 +78,30 @@ public static void main(String[] args) {
 OneVsRest ovr = new OneVsRest().setClassifier(classifier);
 
 String input = params.input;
-RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input);
-RDD<LabeledPoint> train;
-RDD<LabeledPoint> test;
+DataFrame inputData = jsql.read().format("libsvm").load(input);
+DataFrame train;
+DataFrame test;
 
 // compute the train/test split: if testInput is not provided use part of input
 String testInput = params.testInput;
 if (testInput != null) {
   train = inputData;
   // compute the number of features in the training set.
-  int numFeatures = inputData.first().features().size();
-  test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures);
+  int numFeatures = inputData.first().<Vector>getAs(1).size();
+  test = jsql.read().format("libsvm").option("numFeatures",
+    String.valueOf(numFeatures)).load(testInput);
 } else {
   double f = params.fracTest;
-  RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
+  DataFrame[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
   train = tmp[0];
   test = tmp[1];
 }
 
 // train the multiclass model
-DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
-OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());
+OneVsRestModel ovrModel = ovr.fit(train.cache());
 
 // score the model on test data
-DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
-DataFrame predictions = ovrModel.transform(testDataFrame.cache())
+DataFrame predictions = ovrModel.transform(test.cache())
   .select("prediction", "label");
 
 // obtain metrics
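
One subtlety in this hunk: a LIBSVM file does not record the overall feature dimension, so a test set loaded separately could infer a smaller vector size than the training set. The `numFeatures` option pins the dimension explicitly, passed as a string via `String.valueOf`. A minimal sketch of that pattern (class and method names are hypothetical):

```java
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class NumFeaturesSketch {
  // Load a test set whose vectors match the training set's dimension.
  static DataFrame loadTest(SQLContext jsql, DataFrame train, String testPath) {
    // The "features" column is column 1 in the (label, features) schema.
    int numFeatures = train.first().<Vector>getAs(1).size();

    // Pin the dimension so the model and the test data agree.
    return jsql.read().format("libsvm")
      .option("numFeatures", String.valueOf(numFeatures))
      .load(testPath);
  }
}
```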

examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java

Lines changed: 1 addition & 5 deletions

@@ -23,8 +23,6 @@
 import org.apache.spark.ml.param.ParamMap;
 import org.apache.spark.ml.regression.LinearRegression;
 import org.apache.spark.ml.tuning.*;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
 
@@ -46,9 +44,7 @@ public static void main(String[] args) {
 JavaSparkContext jsc = new JavaSparkContext(conf);
 SQLContext jsql = new SQLContext(jsc);
 
-DataFrame data = jsql.createDataFrame(
-  MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
-  LabeledPoint.class);
+DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
 
 // Prepare training and test data.
 DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);

examples/src/main/python/ml/decision_tree_classification_example.py

Lines changed: 2 additions & 3 deletions

@@ -28,16 +28,15 @@
 from pyspark.ml.classification import DecisionTreeClassifier
 from pyspark.ml.feature import StringIndexer, VectorIndexer
 from pyspark.ml.evaluation import MulticlassClassificationEvaluator
-from pyspark.mllib.util import MLUtils
 # $example off$
 
 if __name__ == "__main__":
     sc = SparkContext(appName="decision_tree_classification_example")
     sqlContext = SQLContext(sc)
 
     # $example on$
-    # Load and parse the data file, converting it to a DataFrame.
-    data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+    # Load the data stored in LIBSVM format as a DataFrame.
+    data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
 
     # Index labels, adding metadata to the label column.
     # Fit on whole dataset to include all labels in index.
