Skip to content

Commit 0e64614

Browse files
author
Tom
committed
Added some JUnit tests
1 parent b4f7fcf commit 0e64614

File tree

11 files changed

+199
-95
lines changed

11 files changed

+199
-95
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Expectation Maximization for Geographical Data
2+
3+
An implementation of the expectation-maximization algorithm for a mixture of
4+
gaussians, designed to organize geographical data into clusters.
5+
6+
## To Do
7+
* Fixing issues with NaN
8+
* Mean initialization
9+
* Heuristic for identification of number of clusters
10+
* Testing on real-world data

java/.classpath

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
<classpath>
33
<classpathentry kind="src" path=""/>
44
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
5+
<classpathentry kind="con" path="org.eclipse.jdt.USER_LIBRARY/Apache Commons Math"/>
56
<classpathentry kind="lib" path="commons-math3-3.1.1.jar"/>
7+
<classpathentry kind="con" path="org.eclipse.jdt.junit.JUNIT_CONTAINER/4"/>
68
<classpathentry kind="output" path=""/>
79
</classpath>

java/ExpMax.java renamed to java/edu/cmu/ml/geoEM/ExpMax.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
package edu.cmu.ml.geoEM;
12
/*
23
* ExpMax.java
34
* Description: A implementation of the expectation-maximization algorithm
@@ -19,7 +20,7 @@ public class ExpMax
1920
public int numPoints;
2021
public double[] means;
2122
public double[] sigmas;
22-
private double epsilon = 0.01;
23+
protected double epsilon = 0.01;
2324

2425

2526
public ExpMax(double[] data, int k) {
@@ -51,7 +52,7 @@ public double[][] calculateParameters() {
5152
return new double[][] {means, sigmas};
5253
}
5354

54-
private boolean compare(double[] curr, double[] old) {
55+
protected boolean compare(double[] curr, double[] old) {
5556
if(old == null)
5657
return true;
5758
double diff = 0.0;
@@ -60,15 +61,15 @@ private boolean compare(double[] curr, double[] old) {
6061
return FastMath.pow(diff, 0.5) > epsilon;
6162
}
6263

63-
private double[][] calculateExpectation() {
64+
protected double[][] calculateExpectation() {
6465
double[][] expectedValues = new double[numDist][numPoints];
6566
for(int i = 0; i < numDist; i++)
6667
for(int j = 0; j < numPoints; j++)
6768
expectedValues[i][j] = expectedValuePoint(data[j], i);
6869
return expectedValues;
6970
}
7071

71-
private double expectedValuePoint(double point, int currDist) {
72+
protected double expectedValuePoint(double point, int currDist) {
7273
double probCurrDist = probPoint(point, means[currDist], sigmas[currDist]);
7374
double probAllDist = 0;
7475
for(int i = 0; i < numDist; i++)
@@ -78,12 +79,12 @@ private double expectedValuePoint(double point, int currDist) {
7879
return probCurrDist / probAllDist;
7980
}
8081

81-
private static double probPoint(double point, double mean, double sigma) {
82+
protected static double probPoint(double point, double mean, double sigma) {
8283
NormalDistribution dist = new NormalDistribution(mean, sigma);
8384
return dist.density(point);
8485
}
8586

86-
private void calculateHypothesis(double[][] expectedValues) {
87+
protected void calculateHypothesis(double[][] expectedValues) {
8788
for(int i = 0; i < numDist; i++) {
8889
double totalExp = 0;
8990
means[i] = 0;

java/edu/cmu/ml/geoEM/ExpMaxTest.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package edu.cmu.ml.geoEM;
2+
3+
import static org.junit.Assert.*;
4+
5+
import java.io.IOException;
6+
import java.text.DecimalFormat;
7+
import java.util.Arrays;
8+
9+
import org.junit.Test;
10+
11+
public class ExpMaxTest {
12+
@Test
13+
public void testexpectedValuePoint() {
14+
double copyValueFromPython = 0.00539909665132;
15+
assertEquals(copyValueFromPython,ExpMax.probPoint(20, 0, 10),1e-4);
16+
17+
// randomly generated test values in Python
18+
copyValueFromPython = 6.66558583233e-17;
19+
assertEquals(copyValueFromPython,ExpMax.probPoint(-34, 89, 15),1e-4);
20+
21+
copyValueFromPython = 1.2410076451e-6;
22+
assertEquals(copyValueFromPython,ExpMax.probPoint(-28, 0, 6),1e-4);
23+
24+
copyValueFromPython = 0.00437031484895;
25+
assertEquals(copyValueFromPython,ExpMax.probPoint(-61, -75, 6),1e-4);
26+
27+
copyValueFromPython = 0.0419314697437;
28+
assertEquals(copyValueFromPython,ExpMax.probPoint(-55, -52, 9),1e-4);
29+
}
30+
31+
@Test
32+
public void testExpMax() throws IOException {
33+
int k = 2;
34+
// Util.exportFile("temp.txt", KGauss.kgauss(k, 100, 1, -100, 100, 0.1));
35+
double[] data = Util.importFile("temp.0")[0];
36+
double[][] pythonParams = new double[][] { { 5.99909154243, 56.0399394117, }, { 6.15773090876, 5.34897679826, } };
37+
ExpMax em = new ExpMax(data, k);
38+
double[][] params = em.calculateParameters();
39+
Util.roundArray(pythonParams);
40+
Util.roundArray(params);
41+
assertArrayEquals(pythonParams, params);
42+
43+
data = Util.importFile("temp.1")[0];
44+
pythonParams = new double[][] { { -0.410591419827, 28.5893379139, }, { 6.55528895793, 6.84560112468, }, };
45+
em = new ExpMax(data, k);
46+
params = em.calculateParameters();
47+
Util.roundArray(pythonParams);
48+
Util.roundArray(params);
49+
assertArrayEquals(pythonParams, params);
50+
51+
data = Util.importFile("temp.2")[0];
52+
pythonParams = new double[][] { { -94.4554393446, 11.6315567616, }, { 7.71865613213, 7.79345596918, }, }; em = new ExpMax(data, k);
53+
params = em.calculateParameters();
54+
Util.roundArray(pythonParams);
55+
Util.roundArray(params);
56+
assertArrayEquals(pythonParams, params);
57+
}
58+
59+
}

java/KGauss.java renamed to java/edu/cmu/ml/geoEM/KGauss.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
package edu.cmu.ml.geoEM;
12
import java.util.*;
23
import java.io.*;
34
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;

java/MultiExpMax.java renamed to java/edu/cmu/ml/geoEM/MultiExpMax.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
package edu.cmu.ml.geoEM;
12
import java.io.IOException;
23
import java.util.*;
34

@@ -22,15 +23,14 @@ public MultiExpMax(double[][] data, int k) {
2223

2324
means = new double[numDist][dim];
2425
/* TODO: initialize means */
25-
means = new double[][] {{-40, 60},
26-
{-50, 30}};
26+
means = new double[][] {{-80.0,40.0,},{-80.0,-10.0,},};
2727

2828
covs = new RealMatrix[numDist];
2929
for(int i = 0; i < numDist; i++) {
3030
covs[i] = new Array2DRowRealMatrix(new double[dim][dim]);
3131
for(int r = 0; r < dim; r++)
3232
for(int c = 0; c < dim; c++)
33-
if(r == c) covs[i].setEntry(r, c, 1.0);
33+
if(r == c) covs[i].setEntry(r, c, 10.0);
3434
}
3535
}
3636

@@ -43,6 +43,7 @@ public void calculateParameters() {
4343
oldMeans = Util.deepcopy(means);
4444
oldCovs = Util.deepcopy(covs);
4545
calculateHypothesis(calculateExpectation());
46+
System.out.println(Util.arrayToString(covs[0].getData()));
4647
} while(compare(means, oldMeans) || compare(covs, oldCovs));
4748
System.out.println("Done! Means: " + Util.arrayToString(means));
4849
}
@@ -100,7 +101,7 @@ private double expectedValuePoint(double[] point, int currDist) {
100101
return probCurrDist / probAllDist;
101102
}
102103

103-
private static double probPoint(double[] point,
104+
public static double probPoint(double[] point,
104105
double[] means, RealMatrix cov) {
105106
MultivariateNormalDistribution dist =
106107
new MultivariateNormalDistribution(means, cov.getData());
@@ -123,9 +124,10 @@ private void calculateHypothesis(double[][] expectedValues) {
123124
for(int r = 0; r < dim; r++) {
124125
for(int c = 0; c < dim; c++) {
125126
for(int j = 0; j < numPoints; j++) {
126-
covs[i].setEntry(r, c, expectedValues[i][j]
127-
* (data[r][j] - means[i][r])
128-
* (data[c][j] - means[i][c]));
127+
double entry = (expectedValues[i][j]
128+
* (data[r][j] - means[i][r])
129+
* (data[c][j] - means[i][c]));
130+
covs[i].setEntry(r, c, entry);
129131
}
130132
covs[i].setEntry(r, c, covs[i].getEntry(r, c) /
131133
(totalExp * (numPoints - 1) / numPoints));
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package edu.cmu.ml.geoEM;
2+
3+
import static org.junit.Assert.*;
4+
5+
import java.io.IOException;
6+
7+
import org.apache.commons.math3.linear.*;
8+
9+
import org.junit.Test;
10+
11+
public class MultiExpMaxTest {
12+
13+
@Test
14+
public void testProbPoint() {
15+
double copyValueFromPython = 6.20250972071e-147;
16+
assertEquals(copyValueFromPython,MultiExpMax.probPoint(new double[] {-27,-6,} , new double[] {-9,-90,}, new Array2DRowRealMatrix(new double[][] {{14,0,},{0,11,},})),1e-150);
17+
18+
copyValueFromPython = 2.95791812698e-315;
19+
assertEquals(copyValueFromPython,MultiExpMax.probPoint(new double[] {-81,-25,} , new double[] {-98,81,}, new Array2DRowRealMatrix(new double[][] {{8,0,},{0,8,},})),1e-320);
20+
21+
copyValueFromPython = 7.80714811896e-167;
22+
assertEquals(copyValueFromPython,MultiExpMax.probPoint(new double[] {49,8,} , new double[] {-21,-41,}, new Array2DRowRealMatrix(new double[][] {{10,0,},{0,9,},})),1e-170);
23+
}
24+
25+
@Test
26+
public void testExpMax() throws IOException {
27+
double[][] data = Util.importFile("temp.m");
28+
MultiExpMax em = new MultiExpMax(data, 2);
29+
em.calculateParameters();
30+
31+
System.out.println("Model means: " + Util.arrayToString(em.means));
32+
for(RealMatrix m : em.covs)
33+
System.out.println("Model cov: "
34+
+ Util.arrayToString(m.getData()));
35+
}
36+
}

java/Util.java renamed to java/edu/cmu/ml/geoEM/Util.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
package edu.cmu.ml.geoEM;
12
import java.io.*;
23
import java.nio.MappedByteBuffer;
34
import java.nio.channels.FileChannel;
45
import java.nio.charset.Charset;
6+
import java.text.DecimalFormat;
57
import java.util.*;
68
import org.apache.commons.math3.linear.*;
79

@@ -87,4 +89,18 @@ public static String exportFile(String filename, double[][] data) {
8789
}
8890
return filename;
8991
}
92+
93+
public static double round(double d, int i) {
94+
String s = "#.";
95+
while(i-- > 0)
96+
s += "#";
97+
DecimalFormat twoDForm = new DecimalFormat(s);
98+
return Double.valueOf(twoDForm.format(d));
99+
}
100+
101+
public static void roundArray(double[][] arr) {
102+
for(int i = 0; i < arr.length; i++)
103+
for(int j = 0; j < arr[0].length; j++)
104+
arr[i][j] = round(arr[i][j], 1);
105+
}
90106
}

python/expmax.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,11 @@ def initial_hypothesis_mu(data, k):
3434
h = []
3535
minval = np.min(data)
3636
maxval = np.max(data)
37-
#"""
3837
interval = (maxval - minval) / k
3938
h.append(minval)
4039
for i in xrange(1, k-1):
4140
h.append(minval + i * interval)
4241
h.append(maxval)
43-
"""
44-
for i in xrange(k):
45-
h.append(randint(int(minval), int(maxval)))
46-
"""
4742
return h
4843

4944
def compare_hypothesis(curr, old, epsilon=0.01):
@@ -108,28 +103,27 @@ def test():
108103
diffs[i] += difflist[j][i]
109104
diffs[i] /= 10
110105

111-
def main():
112-
#data, mus = kgauss(2, 100, dim=1, lower=-100, upper=100, sigma=3)
113-
#exportFile('temp.txt', data)
114-
# data = importFile('temp.txt')
115-
# test()
116-
"""
117-
data = []
118-
for i in kgauss(1, 10, dim=1, lower=-100, upper=100, sigma=3)[0]:
119-
data.append(i)
120-
for i in kgauss(1, 1000, dim=1, lower=-100, upper=100, sigma=3)[0]:
121-
data.append(i)
122-
em = expectation_maximization(np.array(data), 2, 1)
123-
while not (em[0] and em[1]):
124-
em = expectation_maximization(np.array(data), 2, 1)
125-
print 'Let\'s try again'
126-
print em
127-
"""
128-
data, mus, sigmas = kgauss_with_mus_sigmas(2, 100, dim=1, lower=-100, upper=100)
129-
exportFile('temp.txt', data)
130-
print mus, sigmas
131-
print expectation_maximization(data[0], 2, [1, 1])
106+
def test_ppg():
107+
point, mean, sigma = randint(-100, 100), randint(-100, 100), randint(5, 20)
108+
print 'copyValueFromPython = ' + str(prob_point_gauss(point, mean, sigma)) + ';'
109+
print 'assertEquals(copyValueFromPython,ExpMax.probPoint(' + str(point) + ', ' + str(mean) + ', ' + str(sigma) +'),1e-4);'
132110

111+
def test_em():
112+
filename = '..\\Java\\temp2.txt'
113+
data = kgauss(2, 100, dim=1, lower=-100, upper=100, sigma=randint(5, 15))
114+
exportFile(filename, data)
115+
# print data[0]
116+
params = expectation_maximization(data[0], 2, [1, 1])
117+
print 'pythonParams = new double[][] {',
118+
for i in params:
119+
print '{',
120+
for j in i:
121+
print str(j) + ',',
122+
print '},',
123+
print '};'
124+
125+
def main():
126+
test_em()
133127

134128
if __name__ == "__main__":
135129
main()

0 commit comments

Comments
 (0)