Skip to content

Commit 0793ee1

Browse files
sryza authored and srowen committed
SPARK-2149. [MLLIB] Univariate kernel density estimation
Author: Sandy Ryza <sandy@cloudera.com> Closes #1093 from sryza/sandy-spark-2149 and squashes the following commits: 5f06b33 [Sandy Ryza] More review comments 0f73060 [Sandy Ryza] Respond to Sean's review comments 0dfa005 [Sandy Ryza] SPARK-2149. Univariate kernel density estimation
1 parent 4dfe180 commit 0793ee1

File tree

3 files changed

+132
-0
lines changed

3 files changed

+132
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.stat
19+
20+
import org.apache.spark.rdd.RDD
21+
22+
private[stat] object KernelDensity {

  /**
   * Given a set of samples from a distribution, estimates its density at the set of given points.
   * Uses a Gaussian kernel with the given standard deviation.
   *
   * @param samples samples drawn from the distribution whose density is being estimated
   * @param standardDeviation standard deviation of the Gaussian kernel; must be positive
   * @param evaluationPoints points at which to estimate the density
   * @return densities at each evaluation point, parallel to `evaluationPoints`
   * @throws IllegalArgumentException if `standardDeviation` is not positive
   */
  def estimate(samples: RDD[Double], standardDeviation: Double,
      evaluationPoints: Array[Double]): Array[Double] = {
    if (standardDeviation <= 0.0) {
      throw new IllegalArgumentException("Standard deviation must be positive")
    }

    // This gets used in each Gaussian PDF computation, so compute it up front.
    val logStandardDeviationPlusHalfLog2Pi =
      math.log(standardDeviation) + 0.5 * math.log(2 * math.Pi)

    // Accumulate, over all samples, the kernel contribution at each evaluation point,
    // together with the total sample count used to normalize the sums into a mean.
    val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))(
      (x, y) => {
        var i = 0
        while (i < evaluationPoints.length) {
          x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi,
            evaluationPoints(i))
          i += 1
        }
        // Count exactly one sample per seqOp invocation. (Returning the loop index here
        // would yield evaluationPoints.length per non-empty partition instead of the
        // number of samples, mis-normalizing the densities.)
        (x._1, x._2 + 1)
      },
      (x, y) => {
        var i = 0
        while (i < evaluationPoints.length) {
          x._1(i) += y._1(i)
          i += 1
        }
        (x._1, x._2 + y._2)
      })

    // The kernel density estimate is the mean of the per-sample kernel densities.
    var i = 0
    while (i < points.length) {
      points(i) /= count
      i += 1
    }
    points
  }

  /** Gaussian density at `x` for the given mean and standard deviation. */
  private def normPdf(mean: Double, standardDeviation: Double,
      logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = {
    val x0 = x - mean
    val x1 = x0 / standardDeviation
    // exp of the log-density avoids recomputing the normalizing constant per call.
    val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi
    math.exp(logDensity)
  }
}

mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,18 @@ object Statistics {
149149
/**
 * Runs a chi-squared test for each feature of the labeled input data, delegating to
 * [[ChiSqTest.chiSquaredFeatures]].
 *
 * @param data RDD of labeled points whose features are tested.
 * @return An array of test results, one per feature (per the return type; the exact
 *         ordering/semantics are defined by ChiSqTest.chiSquaredFeatures — confirm there).
 */
def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
  ChiSqTest.chiSquaredFeatures(data)
}
152+
153+
/**
 * Estimates the density of the empirical distribution defined by `samples` at each of the
 * supplied evaluation points, using a Gaussian kernel.
 *
 * @param samples The samples RDD used to define the empirical distribution.
 * @param standardDeviation The standard deviation of the kernel Gaussians.
 * @param evaluationPoints The points at which to estimate densities.
 * @return An array the same size as evaluationPoints with the density at each point.
 */
def kernelDensity(samples: RDD[Double], standardDeviation: Double,
    evaluationPoints: Iterable[Double]): Array[Double] = {
  // Materialize the points once and hand the estimation off to KernelDensity.
  val points = evaluationPoints.toArray
  KernelDensity.estimate(samples, standardDeviation, points)
}
152166
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.stat
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.commons.math3.distribution.NormalDistribution
23+
24+
import org.apache.spark.mllib.util.LocalClusterSparkContext
25+
26+
class KernelDensitySuite extends FunSuite with LocalClusterSparkContext {

  test("kernel density single sample") {
    val rdd = sc.parallelize(Array(5.0))
    val evaluationPoints = Array(5.0, 6.0)
    val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
    val normal = new NormalDistribution(5.0, 3.0)
    val acceptableErr = 1e-6
    // Compare each evaluation point against the analytic Gaussian density. The absolute
    // difference is required: a signed difference is trivially < err whenever it is
    // negative, which made the original assertions vacuous. The second assertion must
    // also check densities(1), not densities(0) again.
    assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr)
    assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr)
  }

  test("kernel density multiple samples") {
    val rdd = sc.parallelize(Array(5.0, 10.0))
    val evaluationPoints = Array(5.0, 6.0)
    val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
    val normal1 = new NormalDistribution(5.0, 3.0)
    val normal2 = new NormalDistribution(10.0, 3.0)
    val acceptableErr = 1e-6
    // The expected density is the mean of the two per-sample kernels at each point.
    assert(math.abs(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr)
    assert(math.abs(densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr)
  }
}

0 commit comments

Comments
 (0)