Commit 99f88c2

committed
Fixed bug in PipelineModel.transform* with usage of params. Updated CrossValidatorExample to use more training examples so it is less likely to get a 0-size fold.
1 parent ea34dc6 commit 99f88c2
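
The "0-size fold" concern in the commit message comes from cross-validating on a tiny dataset: with only 4 training examples and numFolds = 2, a random split can leave one fold empty (or with only one class), and evaluation then fails. A minimal sketch (not from this commit; plain Scala, simulating a Bernoulli-style random fold assignment) shows how quickly that risk falls as the training set grows:

import scala.util.Random

object FoldSizeSketch {
  // Estimate the probability that a uniform random 2-fold split of
  // n examples leaves one fold completely empty.
  def emptyFoldProbability(n: Int, trials: Int = 100000, seed: Long = 42L): Double = {
    val rng = new Random(seed)
    val failures = (1 to trials).count { _ =>
      val fold0 = (1 to n).count(_ => rng.nextBoolean())
      fold0 == 0 || fold0 == n // one of the two folds got nothing
    }
    failures.toDouble / trials
  }

  def main(args: Array[String]): Unit = {
    // Analytically the probability is 2 * (1/2)^n: 12.5% at n = 4 but only
    // ~0.05% at n = 12, which is why the examples now use 12 documents.
    println(f"n = 4:  ${emptyFoldProbability(4)}%.4f")
    println(f"n = 12: ${emptyFoldProbability(12)}%.4f")
  }
}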

File tree

4 files changed: +34 -22 lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java

Lines changed: 11 additions & 5 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark.examples.ml;
 
-import java.util.ArrayList;
 import java.util.List;
 
 import com.google.common.collect.Lists;
@@ -28,7 +27,6 @@
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineStage;
 import org.apache.spark.ml.classification.LogisticRegression;
-import org.apache.spark.ml.classification.LogisticRegressionModel;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.HashingTF;
 import org.apache.spark.ml.feature.Tokenizer;
@@ -65,7 +63,15 @@ public static void main(String[] args) {
       new LabeledDocument(0L, "a b c d e spark", 1.0),
       new LabeledDocument(1L, "b d", 0.0),
       new LabeledDocument(2L, "spark f g h", 1.0),
-      new LabeledDocument(3L, "hadoop mapreduce", 0.0));
+      new LabeledDocument(3L, "hadoop mapreduce", 0.0),
+      new LabeledDocument(4L, "b spark who", 1.0),
+      new LabeledDocument(5L, "g d a y", 0.0),
+      new LabeledDocument(6L, "spark fly", 1.0),
+      new LabeledDocument(7L, "was mapreduce", 0.0),
+      new LabeledDocument(8L, "e spark program", 1.0),
+      new LabeledDocument(9L, "a e c l", 0.0),
+      new LabeledDocument(10L, "spark compile", 1.0),
+      new LabeledDocument(11L, "hadoop software", 0.0));
     JavaSchemaRDD training =
       jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
 
@@ -112,8 +118,8 @@ public static void main(String[] args) {
       new Document(7L, "apache hadoop"));
     JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
 
-    // Make predictions on test documents.
-    lrModel.transform(test).registerAsTable("prediction");
+    // Make predictions on test documents. cvModel uses the best model found (lrModel).
+    cvModel.transform(test).registerAsTable("prediction");
     JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
     for (Row r: predictions.collect()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
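
Both the Java and Scala examples make the same substitution at prediction time: instead of extracting the best model and calling transform on it, they call transform on the CrossValidatorModel itself, which applies the best model found during the parameter search. A simplified sketch of that delegation (hypothetical types, not the real spark.ml classes):

// Hypothetical, simplified stand-ins for Transformer / CrossValidatorModel,
// just to show the delegation: transform on the cross-validated model is
// transform on the best model it selected.
trait SimpleTransformer {
  def transform(documents: Seq[String]): Seq[(String, Double)]
}

class SimpleCrossValidatorModel(val bestModel: SimpleTransformer) extends SimpleTransformer {
  override def transform(documents: Seq[String]): Seq[(String, Double)] =
    bestModel.transform(documents)
}

Calling cvModel.transform(test) therefore gives the same predictions as lrModel.transform(test), while keeping the example on the higher-level API.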

examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala

Lines changed: 13 additions & 13 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.examples.ml
 
 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
 import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
@@ -50,7 +51,15 @@ object CrossValidatorExample {
       LabeledDocument(0L, "a b c d e spark", 1.0),
       LabeledDocument(1L, "b d", 0.0),
       LabeledDocument(2L, "spark f g h", 1.0),
-      LabeledDocument(3L, "hadoop mapreduce", 0.0)))
+      LabeledDocument(3L, "hadoop mapreduce", 0.0),
+      LabeledDocument(4L, "b spark who", 1.0),
+      LabeledDocument(5L, "g d a y", 0.0),
+      LabeledDocument(6L, "spark fly", 1.0),
+      LabeledDocument(7L, "was mapreduce", 0.0),
+      LabeledDocument(8L, "e spark program", 1.0),
+      LabeledDocument(9L, "a e c l", 0.0),
+      LabeledDocument(10L, "spark compile", 1.0),
+      LabeledDocument(11L, "hadoop software", 0.0)))
 
     // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
     val tokenizer = new Tokenizer()
@@ -81,16 +90,7 @@ object CrossValidatorExample {
     crossval.setNumFolds(2)
 
     // Run cross-validation, and choose the best set of parameters.
-    val cvModel = try {
-      crossval.fit(training)
-    } catch {
-      case e: Exception =>
-        println("\nSTACK TRACE\n")
-        println(e.getStackTraceString)
-        println("\nSTACK TRACE OF CAUSE\n")
-        println(e.getCause.getStackTraceString)
-        throw e
-    }
+    val cvModel = crossval.fit(training)
     // Get the best LogisticRegression model (with the best set of parameters from paramGrid).
     val lrModel = cvModel.bestModel
 
@@ -101,8 +101,8 @@ object CrossValidatorExample {
       Document(6L, "mapreduce spark"),
       Document(7L, "apache hadoop")))
 
-    // Make predictions on test documents using the best LogisticRegression model.
-    lrModel.transform(test)
+    // Make predictions on test documents. cvModel uses the best model found (lrModel).
+    cvModel.transform(test)
       .select('id, 'text, 'score, 'prediction)
       .collect()
       .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 7 additions & 3 deletions
@@ -162,11 +162,15 @@ class PipelineModel private[ml] (
   }
 
   override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
-    transformSchema(dataset.schema, paramMap, logging = true)
-    stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
+    // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
+    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+    transformSchema(dataset.schema, map, logging = true)
+    stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
+    // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
+    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+    stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
   }
 }
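
This is the PipelineModel.transform* bug fix from the commit message: the old code forwarded only the caller's paramMap to each stage, ignoring parameters captured at fit time and set on the model itself. Assuming ParamMap's ++ resolves key conflicts like Scala's Map ++ (entries on the right win), building (fittingParamMap ++ this.paramMap) ++ paramMap yields exactly the precedence stated in the comment. A plain-Scala sketch with ordinary Maps standing in for ParamMap:

object ParamPrecedenceSketch extends App {
  // Stand-ins for the three ParamMaps; on a key conflict, the right operand
  // of ++ wins, so the last map merged in has the highest precedence.
  val fittingParamMap = Map("maxIter" -> 10, "regParam" -> 0.1, "threshold" -> 0.5)
  val modelParamMap   = Map("maxIter" -> 20)                     // this.paramMap
  val callerParamMap  = Map("maxIter" -> 50, "regParam" -> 0.01) // transform's argument

  // Precedence: callerParamMap > modelParamMap > fittingParamMap
  val merged = (fittingParamMap ++ modelParamMap) ++ callerParamMap

  // merged: maxIter -> 50 (caller wins), regParam -> 0.01 (caller wins),
  // threshold -> 0.5 (only set at fitting time, so it survives)
  println(merged)
}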

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 3 additions & 1 deletion
@@ -92,7 +92,9 @@ private[spark] object BLAS extends Serializable with Logging {
    * dot(x, y)
    */
   def dot(x: Vector, y: Vector): Double = {
-    require(x.size == y.size)
+    require(x.size == y.size,
+      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
+      " x.size = " + x.size + ", y.size = " + y.size)
     (x, y) match {
       case (dx: DenseVector, dy: DenseVector) =>
         dot(dx, dy)
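
One detail worth noting about the new assertion: Scala's Predef.require(requirement: Boolean, message: => Any) takes the message by name, so the string concatenation is only evaluated when the check fails, and the happy path pays nothing for the richer error. A minimal sketch (not from this commit):

object RequireMessageSketch extends App {
  def dotSizeCheck(xSize: Int, ySize: Int): Unit =
    require(xSize == ySize,
      // By-name message: this concatenation runs only on failure.
      "BLAS.dot was given Vectors with non-matching sizes:" +
      " x.size = " + xSize + ", y.size = " + ySize)

  dotSizeCheck(3, 3) // passes; the message string is never built
  dotSizeCheck(3, 4) // throws IllegalArgumentException with the descriptive message
}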
