Commit fe6dcfe

Added several Java-friendly APIs + unit tests: NaiveBayes, GaussianMixture, LDA, StreamingKMeans, Statistics.corr, params
1 parent 3c01568 commit fe6dcfe

14 files changed

+276 -16 lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 9 additions & 11 deletions
@@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
     }
   }
 
-  /**
-   * Creates a param pair with the given value (for Java).
-   */
+  /** Creates a param pair with the given value (for Java). */
   def w(value: T): ParamPair[T] = this -> value
 
-  /**
-   * Creates a param pair with the given value (for Scala).
-   */
+  /** Creates a param pair with the given value (for Scala). */
   def ->(value: T): ParamPair[T] = ParamPair(this, value)
 
   override final def toString: String = s"${parent}__$name"
@@ -190,6 +186,7 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>
 
   def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
 
+  /** Creates a param pair with the given value (for Java). */
   override def w(value: Double): ParamPair[Double] = super.w(value)
 }
 
@@ -209,6 +206,7 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea
 
   def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
 
+  /** Creates a param pair with the given value (for Java). */
   override def w(value: Int): ParamPair[Int] = super.w(value)
 }
 
@@ -228,6 +226,7 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo
 
   def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
 
+  /** Creates a param pair with the given value (for Java). */
   override def w(value: Float): ParamPair[Float] = super.w(value)
 }
 
@@ -247,6 +246,7 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool
 
   def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
 
+  /** Creates a param pair with the given value (for Java). */
   override def w(value: Long): ParamPair[Long] = super.w(value)
 }
 
@@ -260,6 +260,7 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV
 
   def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
 
+  /** Creates a param pair with the given value (for Java). */
   override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
 }
 
@@ -274,8 +275,6 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
   def this(parent: Params, name: String, doc: String) =
     this(parent, name, doc, ParamValidators.alwaysTrue)
 
-  override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
-
   /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
   def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
 }
@@ -291,10 +290,9 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
   def this(parent: Params, name: String, doc: String) =
     this(parent, name, doc, ParamValidators.alwaysTrue)
 
-  override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value)
-
   /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
-  def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray)
+  def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
+    w(value.asScala.map(_.asInstanceOf[Double]).toArray)
 }
 
 /**
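
Example (not part of the commit): a minimal Java sketch of the `w` param-pair builders. It assumes the `JavaTestParams` test class shown later in this commit has a public no-arg constructor, places the example in the same package so no extra imports are needed, and declares the pairs with wildcard type arguments because Scala's primitive type parameters surface as `Object` from Java.

package org.apache.spark.ml.param;

import java.util.Arrays;

public class ParamPairExample {
  public static void main(String[] args) {
    JavaTestParams p = new JavaTestParams();
    // Java cannot call the Scala `->` operator, so w(...) builds the pair instead.
    ParamPair<?> intPair = p.myIntParam().w(2);
    // New overload in this commit: takes a java.util.List<Double>, converts to double[].
    ParamPair<?> arrayPair = p.myDoubleArrayParam().w(Arrays.asList(1.0, 2.0));
    System.out.println(intPair + " " + arrayPair);
  }
}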

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.{Logging, SparkContext, SparkException}
 import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
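
Only the new import appears in this hunk; the Java-friendly NaiveBayes API itself (presumably a `run(JavaRDD<LabeledPoint>)` overload, per the commit message) lies outside the displayed lines. Example (not part of the commit; class name and toy data are illustrative): calling MLlib's NaiveBayes from Java via the long-standing `train` entry point, with `data.rdd()` bridging explicitly to the Scala RDD.

import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

public class JavaNaiveBayesExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaNaiveBayesExample");
    JavaRDD<LabeledPoint> data = sc.parallelize(Lists.newArrayList(
      new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
      new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))));
    // train takes a Scala RDD; data.rdd() is the explicit Java-to-Scala bridge.
    NaiveBayesModel model = NaiveBayes.train(data.rdd());
    System.out.println(model.predict(Vectors.dense(1.0, 0.0)));
    sc.stop();
  }
}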

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala

Lines changed: 5 additions & 1 deletion
@@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq
 import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
 import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
 import org.apache.spark.mllib.util.MLUtils
@@ -188,7 +189,10 @@ class GaussianMixture private (
     new GaussianMixtureModel(weights, gaussians)
   }
 
-  /** Average of dense breeze vectors */
+  /** Java-friendly version of [[run()]] */
+  def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
+
+  /** Average of dense breeze vectors */
   private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
     val v = BDV.zeros[Double](x(0).length)
     x.foreach(xi => v += xi)

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 6 additions & 1 deletion
@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
 import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
 import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
@@ -46,7 +47,7 @@ import org.apache.spark.sql.{SQLContext, Row}
 @Experimental
 class GaussianMixtureModel(
   val weights: Array[Double],
-  val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
+  val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
 
   require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
 
@@ -65,6 +66,10 @@ class GaussianMixtureModel(
     responsibilityMatrix.map(r => r.indexOf(r.max))
   }
 
+  /** Java-friendly version of [[predict()]] */
+  def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+    predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
+
   /**
    * Given the input vectors, return the membership value of each vector
    * to all mixture components.

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
 import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
 import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
 import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
 import org.apache.spark.rdd.RDD
@@ -345,6 +346,12 @@ class DistributedLDAModel private (
     }
   }
 
+  /** Java-friendly version of [[topicDistributions]] */
+  def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
+    new JavaPairRDD[java.lang.Long, Vector](
+      topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
+  }
+
   // TODO:
   // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
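
Example (not part of the commit): reaching `javaTopicDistributions` end to end from Java. It assumes the `JavaPairRDD` `run` entry point that `JavaLDASuite` already uses, keeps the toy corpus and parameter values illustrative, and casts the result to `DistributedLDAModel` as that suite does.

import scala.Tuple2;

import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.DistributedLDAModel;
import org.apache.spark.mllib.clustering.LDA;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class JavaLDAExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaLDAExample");
    // Toy corpus of (document id, term-count vector) pairs.
    JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(sc.parallelize(Lists.newArrayList(
      new Tuple2<Long, Vector>(0L, Vectors.dense(1.0, 0.0, 2.0)),
      new Tuple2<Long, Vector>(1L, Vectors.dense(0.0, 3.0, 1.0)))));
    DistributedLDAModel model =
      (DistributedLDAModel) new LDA().setK(2).setMaxIterations(5).run(corpus);
    // New in this commit: a boxed-key JavaPairRDD view of topicDistributions.
    JavaPairRDD<Long, Vector> topics = model.javaTopicDistributions();
    System.out.println(topics.count());
    sc.stop();
  }
}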

mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala

Lines changed: 18 additions & 0 deletions
@@ -21,8 +21,10 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext._
 import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream}
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
@@ -234,6 +236,9 @@ class StreamingKMeans(
     }
   }
 
+  /** Java-friendly version of `trainOn`. */
+  def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream)
+
   /**
    * Use the clustering model to make predictions on batches of data from a DStream.
    *
@@ -245,6 +250,11 @@ class StreamingKMeans(
     data.map(model.predict)
   }
 
+  /** Java-friendly version of `predictOn`. */
+  def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = {
+    JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]])
+  }
+
   /**
    * Use the model to make predictions on the values of a DStream and carry over its keys.
    *
@@ -257,6 +267,14 @@ class StreamingKMeans(
     data.mapValues(model.predict)
   }
 
+  /** Java-friendly version of `predictOnValues`. */
+  def predictOnValues[K](
+      data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = {
+    implicit val tag = fakeClassTag[K]
+    JavaPairDStream.fromPairDStream(
+      predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]])
+  }
+
   /** Check whether cluster centers have been initialized. */
   private[this] def assertInitialized(): Unit = {
     if (model.clusterCenters == null) {
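
Example (not part of the commit): driving the new Java-friendly streaming methods. The queue-based input, parameter values, and class name are all illustrative only.

import com.google.common.collect.Lists;

import org.apache.spark.mllib.clustering.StreamingKMeans;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class JavaStreamingKMeansExample {
  public static void main(String[] args) throws Exception {
    JavaStreamingContext jssc =
      new JavaStreamingContext("local[2]", "JavaStreamingKMeans", new Duration(1000));
    // A one-batch queue stream stands in for a real input source.
    JavaDStream<Vector> training = jssc.queueStream(Lists.newLinkedList(Lists.newArrayList(
      jssc.sparkContext().parallelize(Lists.newArrayList(
        Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0))))));
    StreamingKMeans model = new StreamingKMeans()
      .setK(2)
      .setDecayFactor(1.0)
      .setRandomCenters(2, 0.0, 1234);  // dimension, initial weight, seed
    model.trainOn(training);  // new JavaDStream overload from this commit
    JavaDStream<Integer> predictions = model.predictOn(training);  // boxed Integers
    predictions.print();
    jssc.start();
    jssc.awaitTerminationOrTimeout(3000);
    jssc.stop();
  }
}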

mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.stat
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.distributed.RowMatrix
 import org.apache.spark.mllib.linalg.{Matrix, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -80,6 +81,10 @@ object Statistics {
   */
   def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
 
+  /** Java-friendly version of [[corr()]] */
+  def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
+    corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])
+
   /**
    * Compute the correlation for the input RDDs using the specified method.
    * Methods currently supported: `pearson` (default), `spearman`.
@@ -96,6 +101,9 @@ object Statistics {
   */
   def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
 
+  def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
+    corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)
+
   /**
    * Conduct Pearson's chi-squared goodness of fit test of the observed data against the
    * expected distribution.
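
Example (not part of the commit; class name and toy data are illustrative): the new overloads accept boxed `JavaRDD<Double>` directly, which is what Java callers naturally have; the unboxing cast happens inside the overload. The inputs below are perfectly correlated, so the Pearson value should print 1.0.

import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.stat.Statistics;

public class JavaCorrExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaCorrExample");
    JavaRDD<Double> x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0));
    JavaRDD<Double> y = sc.parallelize(Lists.newArrayList(2.0, 4.0, 6.0));
    double pearson = Statistics.corr(x, y);               // default method: Pearson
    double spearman = Statistics.corr(x, y, "spearman");  // rank correlation
    System.out.println(pearson + " " + spearman);
    sc.stop();
  }
}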

mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ public void testParams() {
     testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
     Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
     Assert.assertEquals(testParams.getMyStringParam(), "a");
+    Assert.assertEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0});
   }
 
   @Test

mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java

Lines changed: 17 additions & 1 deletion
@@ -72,15 +72,31 @@ public JavaTestParams setMyStringParam(String value) {
     set(myStringParam_, value); return this;
   }
 
+  private DoubleArrayParam myDoubleArrayParam_;
+  public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
+
+  public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
+
+  public JavaTestParams setMyDoubleArrayParam(double[] value) {
+    set(myDoubleArrayParam_, value); return this;
+  }
+
   private void init() {
-    myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
+    myIntParam_ = new IntParam(this, "myIntParam", "this is an int param",
+      ParamValidators.gt(0));
     myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
       ParamValidators.inRange(0.0, 1.0));
     List<String> validStrings = Lists.newArrayList("a", "b");
     myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
       ParamValidators.inArray(validStrings));
+    myDoubleArrayParam_ =
+      new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
+
     setDefault(myIntParam_, 1);
+    setDefault(myIntParam_.w(1));
     setDefault(myDoubleParam_, 0.5);
     setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
+    setDefault(myDoubleArrayParam_, new double[]{1.0, 2.0});
+    setDefault(myDoubleArrayParam_.w(new double[]{1.0, 2.0}));
   }
 }

mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java renamed to mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java

Lines changed: 1 addition & 2 deletions
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.ml.classification;
+package org.apache.spark.mllib.classification;
 
 import java.io.Serializable;
 import java.util.List;
@@ -28,7 +28,6 @@
 import org.junit.Test;
 
 import org.apache.spark.SparkConf;
-import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+public class JavaGaussianMixtureSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaGaussianMixture");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void runGaussianMixture() {
+    List<Vector> points = Lists.newArrayList(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
+    );
+
+    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
+      .run(data);
+    assertEquals(model.gaussians().length, 2);
+    JavaRDD<Integer> predictions = model.predict(data);
+    predictions.first();
+  }
+}

mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java

Lines changed: 4 additions & 0 deletions
@@ -107,6 +107,10 @@ public void distributedLDAModel() {
     // Check: log probabilities
     assert(model.logLikelihood() < 0.0);
     assert(model.logPrior() < 0.0);
+
+    // Check: topic distributions
+    JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
+    assertEquals(topicDistributions.count(), corpus.count());
   }
 
   @Test
