17
17
18
18
package org .apache .spark .ml .classification ;
19
19
20
+ import scala .Tuple2 ;
21
+
20
22
import java .io .Serializable ;
23
+ import java .lang .Math ;
24
+ import java .util .ArrayList ;
21
25
import java .util .List ;
22
26
23
27
import org .junit .After ;
24
28
import org .junit .Before ;
25
29
import org .junit .Test ;
26
30
31
+ import org .apache .spark .api .java .JavaRDD ;
27
32
import org .apache .spark .api .java .JavaSparkContext ;
28
33
import org .apache .spark .mllib .regression .LabeledPoint ;
29
34
import org .apache .spark .sql .DataFrame ;
30
35
import org .apache .spark .sql .SQLContext ;
31
36
import static org .apache .spark .mllib .classification .LogisticRegressionSuite .generateLogisticInputAsList ;
37
+ import org .apache .spark .api .java .function .Function ;
38
+ import org .apache .spark .mllib .linalg .Vector ;
39
+ import org .apache .spark .ml .LabeledPoint ;
40
+ import org .apache .spark .sql .Row ;
41
+
32
42
33
43
public class JavaLogisticRegressionSuite implements Serializable {
34
44
35
45
private transient JavaSparkContext jsc ;
36
46
private transient SQLContext jsql ;
37
47
private transient DataFrame dataset ;
38
48
49
+ private transient JavaRDD <LabeledPoint > datasetRDD ;
50
+ private transient JavaRDD <Vector > featuresRDD ;
51
+ private double eps = 1e-5 ;
52
+
39
53
@ Before
40
54
public void setUp () {
41
55
jsc = new JavaSparkContext ("local" , "JavaLogisticRegressionSuite" );
42
56
jsql = new SQLContext (jsc );
43
- List <LabeledPoint > points = generateLogisticInputAsList (1.0 , 1.0 , 100 , 42 );
44
- dataset = jsql .applySchema (jsc .parallelize (points , 2 ), LabeledPoint .class );
57
+ List <LabeledPoint > points = new ArrayList <LabeledPoint >();
58
+ for (org .apache .spark .mllib .regression .LabeledPoint lp :
59
+ generateLogisticInputAsList (1.0 , 1.0 , 100 , 42 )) {
60
+ points .add (new LabeledPoint (lp .label (), lp .features ()));
61
+ }
62
+ datasetRDD = jsc .parallelize (points , 2 );
63
+ featuresRDD = datasetRDD .map (new Function <LabeledPoint , Vector >() {
64
+ @ Override public Vector call (LabeledPoint lp ) { return lp .features (); }
65
+ });
66
+ dataset = jsql .applySchema (datasetRDD , LabeledPoint .class );
67
+ dataset .registerTempTable ("dataset" );
45
68
}
46
69
47
70
@ After
@@ -51,29 +74,112 @@ public void tearDown() {
51
74
}
52
75
53
76
@ Test
54
- public void logisticRegression () {
77
+ public void logisticRegressionDefaultParams () {
55
78
LogisticRegression lr = new LogisticRegression ();
79
+ assert (lr .getLabelCol ().equals ("label" ));
56
80
LogisticRegressionModel model = lr .fit (dataset );
57
81
model .transform (dataset ).registerTempTable ("prediction" );
58
82
DataFrame predictions = jsql .sql ("SELECT label, score, prediction FROM prediction" );
59
83
predictions .collectAsList ();
84
+ // Check defaults
85
+ assert (model .getThreshold () == 0.5 );
86
+ assert (model .getFeaturesCol ().equals ("features" ));
87
+ assert (model .getPredictionCol ().equals ("prediction" ));
88
+ assert (model .getScoreCol ().equals ("score" ));
60
89
}
61
90
62
91
@ Test
63
92
public void logisticRegressionWithSetters () {
93
+ // Set params, train, and check as many params as we can.
64
94
LogisticRegression lr = new LogisticRegression ()
65
95
.setMaxIter (10 )
66
- .setRegParam (1.0 );
96
+ .setRegParam (1.0 )
97
+ .setThreshold (0.6 )
98
+ .setScoreCol ("probability" );
67
99
LogisticRegressionModel model = lr .fit (dataset );
100
+ assert (model .fittingParamMap ().get (lr .maxIter ()).get () == 10 );
101
+ assert (model .fittingParamMap ().get (lr .regParam ()).get () == 1.0 );
102
+ assert (model .fittingParamMap ().get (lr .threshold ()).get () == 0.6 );
103
+ assert (model .getThreshold () == 0.6 );
104
+
105
+ // Modify model params, and check that the params worked.
106
+ model .setThreshold (1.0 );
107
+ model .transform (dataset ).registerTempTable ("predAllZero" );
108
+ SchemaRDD predAllZero = jsql .sql ("SELECT prediction, probability FROM predAllZero" );
109
+ for (Row r : predAllZero .collectAsList ()) {
110
+ assert (r .getDouble (0 ) == 0.0 );
111
+ }
112
+ // Call transform with params, and check that the params worked.
113
+ /* TODO: USE THIS
68
114
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
69
- .registerTempTable ("prediction" );
115
+ .registerTempTable("prediction");
70
116
DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
71
117
predictions.collectAsList();
118
+ */
119
+
120
+ model .transform (dataset , model .threshold ().w (0.0 ), model .scoreCol ().w ("myProb" ))
121
+ .registerTempTable ("predNotAllZero" );
122
+ SchemaRDD predNotAllZero = jsql .sql ("SELECT prediction, myProb FROM predNotAllZero" );
123
+ boolean foundNonZero = false ;
124
+ for (Row r : predNotAllZero .collectAsList ()) {
125
+ if (r .getDouble (0 ) != 0.0 ) foundNonZero = true ;
126
+ }
127
+ assert (foundNonZero );
128
+
129
+ // Call fit() with new params, and check as many params as we can.
130
+ LogisticRegressionModel model2 = lr .fit (dataset , lr .maxIter ().w (5 ), lr .regParam ().w (0.1 ),
131
+ lr .threshold ().w (0.4 ), lr .scoreCol ().w ("theProb" ));
132
+ assert (model2 .fittingParamMap ().get (lr .maxIter ()).get () == 5 );
133
+ assert (model2 .fittingParamMap ().get (lr .regParam ()).get () == 0.1 );
134
+ assert (model2 .fittingParamMap ().get (lr .threshold ()).get () == 0.4 );
135
+ assert (model2 .getThreshold () == 0.4 );
136
+ assert (model2 .getScoreCol ().equals ("theProb" ));
72
137
}
73
138
74
139
@ Test
75
- public void logisticRegressionFitWithVarargs () {
140
+ public void logisticRegressionPredictorClassifierMethods () {
76
141
LogisticRegression lr = new LogisticRegression ();
77
- lr .fit (dataset , lr .maxIter ().w (10 ), lr .regParam ().w (1.0 ));
142
+
143
+ // fit() vs. train()
144
+ LogisticRegressionModel model1 = lr .fit (dataset );
145
+ LogisticRegressionModel model2 = lr .train (datasetRDD );
146
+ assert (model1 .intercept () == model2 .intercept ());
147
+ assert (model1 .weights ().equals (model2 .weights ()));
148
+ assert (model1 .numClasses () == model2 .numClasses ());
149
+ assert (model1 .numClasses () == 2 );
150
+
151
+ // transform() vs. predict()
152
+ model1 .transform (dataset ).registerTempTable ("transformed" );
153
+ SchemaRDD trans = jsql .sql ("SELECT prediction FROM transformed" );
154
+ JavaRDD <Double > preds = model1 .predict (featuresRDD );
155
+ for (scala .Tuple2 <Row , Double > trans_pred : trans .toJavaRDD ().zip (preds ).collect ()) {
156
+ double t = trans_pred ._1 ().getDouble (0 );
157
+ double p = trans_pred ._2 ();
158
+ assert (t == p );
159
+ }
160
+
161
+ // Check various types of predictions.
162
+ JavaRDD <Vector > rawPredictions = model1 .predictRaw (featuresRDD );
163
+ JavaRDD <Vector > probabilities = model1 .predictProbabilities (featuresRDD );
164
+ JavaRDD <Double > predictions = model1 .predict (featuresRDD );
165
+ double threshold = model1 .getThreshold ();
166
+ for (Tuple2 <Vector , Vector > raw_prob : rawPredictions .zip (probabilities ).collect ()) {
167
+ Vector raw = raw_prob ._1 ();
168
+ Vector prob = raw_prob ._2 ();
169
+ for (int i = 0 ; i < raw .size (); ++i ) {
170
+ double r = raw .apply (i );
171
+ double p = prob .apply (i );
172
+ double pFromR = 1.0 / (1.0 + Math .exp (-r ));
173
+ assert (Math .abs (r - pFromR ) < eps );
174
+ }
175
+ }
176
+ for (Tuple2 <Vector , Double > prob_pred : probabilities .zip (predictions ).collect ()) {
177
+ Vector prob = prob_pred ._1 ();
178
+ double pred = prob_pred ._2 ();
179
+ double probOfPred = prob .apply ((int )pred );
180
+ for (int i = 0 ; i < prob .size (); ++i ) {
181
+ assert (probOfPred >= prob .apply (i ));
182
+ }
183
+ }
78
184
}
79
185
}
0 commit comments