Skip to content

Commit e8b0c4d

Browse files
committed
unit testing for single linkage clustering
1 parent deb0311 commit e8b0c4d

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

src/main/java/com/github/chen0040/clustering/onelink/SingleLinkageClustering.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717
@Getter
1818
@Setter
19-
public class SingleLinkageClustering{
19+
public class SingleLinkageClustering {
2020
private int clusterCount = 10;
2121
private BiFunction<DataRow, double[], Double> distanceMeasure;
2222

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package com.github.chen0040.clustering.onelink;
2+
3+
4+
import com.github.chen0040.clustering.density.DBSCAN;
5+
import com.github.chen0040.data.frame.DataFrame;
6+
import com.github.chen0040.data.frame.DataQuery;
7+
import com.github.chen0040.data.frame.DataRow;
8+
import com.github.chen0040.data.frame.Sampler;
9+
import org.testng.annotations.Test;
10+
11+
import java.util.Random;
12+
13+
import static org.testng.Assert.*;
14+
15+
16+
/**
17+
* Created by xschen on 2/6/2017.
18+
*/
19+
public class SingleLinkageClusteringUnitTest {
20+
private static Random random = new Random();
21+
22+
public static double rand(){
23+
return random.nextDouble();
24+
}
25+
26+
public static double rand(double lower, double upper){
27+
return rand() * (upper - lower) + lower;
28+
}
29+
30+
public static double randn(){
31+
double u1 = rand();
32+
double u2 = rand();
33+
double r = Math.sqrt(-2.0 * Math.log(u1));
34+
double theta = 2.0 * Math.PI * u2;
35+
return r * Math.sin(theta);
36+
}
37+
38+
39+
// unit testing based on example from http://scikit-learn.org/stable/auto_examples/svm/plot_oneclass.html#
40+
@Test
41+
public void testSimple(){
42+
43+
44+
DataQuery.DataFrameQueryBuilder schema = DataQuery.blank()
45+
.newInput("c1")
46+
.newInput("c2")
47+
.newOutput("designed")
48+
.end();
49+
50+
Sampler.DataSampleBuilder negativeSampler = new Sampler()
51+
.forColumn("c1").generate((name, index) -> randn() * 0.3 + (index % 2 == 0 ? 2 : 4))
52+
.forColumn("c2").generate((name, index) -> randn() * 0.3 + (index % 2 == 0 ? 2 : 4))
53+
.forColumn("designed").generate((name, index) -> 0.0)
54+
.end();
55+
56+
Sampler.DataSampleBuilder positiveSampler = new Sampler()
57+
.forColumn("c1").generate((name, index) -> rand(-4, -2))
58+
.forColumn("c2").generate((name, index) -> rand(-2, -4))
59+
.forColumn("designed").generate((name, index) -> 1.0)
60+
.end();
61+
62+
DataFrame data = schema.build();
63+
64+
data = negativeSampler.sample(data, 50);
65+
data = positiveSampler.sample(data, 50);
66+
67+
System.out.println(data.head(10));
68+
69+
SingleLinkageClustering algorithm = new SingleLinkageClustering();
70+
algorithm.setClusterCount(2);
71+
72+
DataFrame learnedData = algorithm.fitAndTransform(data);
73+
74+
for(int i = 0; i < learnedData.rowCount(); ++i){
75+
DataRow tuple = learnedData.row(i);
76+
String clusterId = tuple.getCategoricalTargetCell("cluster");
77+
System.out.println("learned: " + clusterId +"\tknown: "+tuple.target());
78+
}
79+
80+
81+
}
82+
}

0 commit comments

Comments
 (0)