Skip to content

Commit 25f7022

Browse files
Ben EvansBen Evans
authored andcommitted
Decoupled classifier and FeatureSelection, by moving the classifier into FeatyureSelection instead of constantly being passed around
1 parent 89e89fd commit 25f7022

File tree

6 files changed

+91
-91
lines changed

6 files changed

+91
-91
lines changed

Classifier.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import java.io.File;
2-
import java.io.FileNotFoundException;
31
import java.util.Comparator;
42
import java.util.HashMap;
53
import java.util.HashSet;
64
import java.util.PriorityQueue;
7-
import java.util.Scanner;
85
import java.util.Set;
6+
import java.util.stream.Collectors;
7+
import java.util.stream.IntStream;
98

109

1110
/**
@@ -30,6 +29,21 @@ public Classifier(Set<Instance> instances) {
3029
}
3130
}
3231

32+
/**
33+
* Classifies and calculates the percentage
34+
* of correct classifications in the testingSet
35+
* against the training set.
36+
*/
37+
public double classify(){
38+
Instance sampleInstance = training.iterator().next();
39+
int totalFeatures = sampleInstance.getNumFeatures();
40+
41+
// To begin with all features are selected
42+
Set<Integer> allIndices = IntStream.rangeClosed(0, totalFeatures - 1)
43+
.boxed().collect(Collectors.toSet());
44+
45+
return classify(allIndices);
46+
}
3347
/**
3448
* Classifies and calculates the percentage
3549
* of correct classifications in the testingSet
@@ -70,6 +84,7 @@ public int compare(Result a, Result b) {
7084
return correct/(double)testing.size();
7185
}
7286

87+
7388
/**
7489
* Returns the mode of @param list
7590
* @param list

Criteria.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import java.util.Set;
2+
3+
public interface Criteria {
4+
public boolean evaluate(Set<Integer> features, int size);
5+
}

FeatureSelection.java

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,26 @@
88
*/
99
public abstract class FeatureSelection {
1010

11+
private Classifier classifier;
12+
protected Set<Instance> instances;
13+
14+
public FeatureSelection(Set<Instance> instances){
15+
this.instances = instances;
16+
this.classifier = new Classifier(instances);
17+
}
1118

1219
/**
1320
Returns a subset of only the most important features
1421
chosen by some measure.
1522
*/
16-
public abstract Set<Integer> select(Set<Instance> instances, int numFeaturesToSelect);
23+
public abstract Set<Integer> select(int numFeaturesToSelect);
1724

1825
/**
1926
Returns a subset containing only the numFeatures most important
2027
features. If numFeatures is >= original.size(), the original
2128
set is returned.
2229
*/
23-
public abstract Set<Integer> select(Set<Instance> instances, double goalAccuracy);
30+
public abstract Set<Integer> select(double goalAccuracy);
2431

2532
/**
2633
* Returns the feature in remaining features
@@ -30,15 +37,15 @@ public abstract class FeatureSelection {
3037
* @param remainingFeatures
3138
* @return
3239
*/
33-
protected int best(Classifier classifier, Set<Integer> selectedFeatures, Set<Integer> remainingFeatures){
40+
protected int best(Set<Integer> selectedFeatures, Set<Integer> remainingFeatures){
3441
double highest = -Integer.MAX_VALUE;
3542
int selected = -1;
3643

3744
for(int feature: remainingFeatures){
3845
Set<Integer> newFeatures = new HashSet<>(selectedFeatures);
3946
newFeatures.add(feature);
4047

41-
double result = objectiveFunction(classifier, newFeatures);
48+
double result = objectiveFunction(newFeatures);
4249
if(result > highest){
4350
highest = result;
4451
selected = feature;
@@ -55,15 +62,15 @@ protected int best(Classifier classifier, Set<Integer> selectedFeatures, Set<Int
5562
* @param features
5663
* @return
5764
*/
58-
protected int worst(Classifier classifier, Set<Integer> features){
65+
protected int worst(Set<Integer> features){
5966
double lowestAccuracy = Integer.MAX_VALUE;
6067
int selected = -1;
6168

6269
for(int feature: features){
6370
Set<Integer> newFeatures = new HashSet<>(features);
6471
newFeatures.remove(feature);
6572

66-
double result = objectiveFunction(classifier, newFeatures);
73+
double result = objectiveFunction(newFeatures);
6774
if(result < lowestAccuracy){
6875
lowestAccuracy = result;
6976
selected = feature;
@@ -73,15 +80,23 @@ protected int worst(Classifier classifier, Set<Integer> features){
7380
return selected;
7481
}
7582

76-
public void compareAccuracy(Classifier classifier, Set<Integer> selectedIndices, int totalFeatures) {
77-
Set<Integer> allIndices = IntStream.rangeClosed(0, totalFeatures - 1)
83+
protected double objectiveFunction(Set<Integer> selectedFeatures) {
84+
return classifier.classify(selectedFeatures);
85+
}
86+
87+
protected Set<Integer> getFeatures(){
88+
// Extract an instance to check the amount of features, assumes all instances have same # of features
89+
Instance sampleInstance = instances.iterator().next();
90+
int totalFeatures = sampleInstance.getNumFeatures();
91+
92+
// To begin with all features are selected
93+
return IntStream.rangeClosed(0, totalFeatures - 1)
7894
.boxed().collect(Collectors.toSet());
79-
System.out.println("Classification accuracy on testing set using all features: " + classifier.classify(allIndices));
80-
System.out.println("Classification accuracy on testing set using features " + selectedIndices + ": " + classifier.classify(selectedIndices));
8195
}
8296

83-
protected double objectiveFunction(Classifier classifier, Set<Integer> selectedFeatures) {
84-
return classifier.classify(selectedFeatures);
97+
public void compareAccuracy(Set<Integer> selectedIndices) {
98+
System.out.println("Classification accuracy on testing set using all features: " + classifier.classify());
99+
System.out.println("Classification accuracy on testing set using features " + selectedIndices + ": " + classifier.classify(selectedIndices));
85100
}
86101

87102

Instance.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ public Instance(double[] features, String label){
2727

2828
public double getFeature(int index){ return features[index]; }
2929

30+
public Instance createInstance(Set<Integer> featureIndices){
31+
double[] features = new double[featureIndices.size()];
32+
33+
int i=0;
34+
for(int index: featureIndices){
35+
features[i++] = features[index];
36+
}
37+
38+
return new Instance(features , label);
39+
}
40+
3041
public double distanceTo(Instance other, Set<Integer> indices){
3142
if(getNumFeatures() != other.getNumFeatures()) throw new IllegalArgumentException("Number of features do not match");
3243

SequentialBackwardsSelection.java

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,34 @@
11
import java.util.HashSet;
22
import java.util.Set;
3-
import java.util.stream.Collectors;
4-
import java.util.stream.IntStream;
53

64
/**
75
* Created by ben on 18/04/17.
86
*/
97
public class SequentialBackwardsSelection extends FeatureSelection {
108

11-
@Override
12-
public Set<Integer> select(Set<Instance> instances, int numFeaturesToSelect) {
13-
// In this case we have no data to use, so return the empty set
14-
if(instances == null || instances.isEmpty()) return new HashSet<Integer>();
9+
public SequentialBackwardsSelection(Set<Instance> instances){
10+
super(instances);
11+
}
1512

16-
// Extract an instance to check the amount of features, assumes all instances have same # of features
17-
Instance sampleInstance = instances.iterator().next();
18-
int totalFeatures = sampleInstance.getNumFeatures();
13+
@Override
14+
public Set<Integer> select(int numFeaturesToSelect) {
15+
return select((features, size) -> size > numFeaturesToSelect);
16+
}
1917

20-
// To begin with all features are selected
21-
Set<Integer> selectedFeatures = IntStream.rangeClosed(0, totalFeatures - 1)
22-
.boxed().collect(Collectors.toSet());
18+
@Override
19+
public Set<Integer> select(double goalAccuracy) {
20+
return select((features, size) -> objectiveFunction(features) < goalAccuracy);
21+
}
2322

24-
// Nothing we can do if the number of features to select is greater than or equal to the total size
25-
if (numFeaturesToSelect >= totalFeatures){
26-
return selectedFeatures;
27-
}
23+
public Set<Integer> select(Criteria criteria) {
24+
// In this case we have no data to use, so return the empty set
25+
if (instances == null || instances.isEmpty()) return new HashSet<Integer>();
2826

29-
Classifier classifier = new Classifier(instances);
27+
// To begin with all features are selected
28+
Set<Integer> selectedFeatures = getFeatures();
3029

31-
while (selectedFeatures.size() >= numFeaturesToSelect){
32-
int feature = worst(classifier, selectedFeatures);
30+
while (criteria.evaluate(selectedFeatures, selectedFeatures.size())){
31+
int feature = worst(selectedFeatures);
3332

3433
// No more valid features
3534
if (feature == -1) break;
@@ -40,9 +39,4 @@ public Set<Integer> select(Set<Instance> instances, int numFeaturesToSelect) {
4039

4140
return selectedFeatures;
4241
}
43-
44-
@Override
45-
public Set<Integer> select(Set<Instance> instances, double goalAccuracy) {
46-
return null;
47-
}
4842
}

SequentialForwardSelection.java

Lines changed: 12 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,47 @@
11
import java.util.HashSet;
22
import java.util.Set;
3-
import java.util.stream.Collectors;
4-
import java.util.stream.IntStream;
53

64
/**
75
* Created by ben on 8/04/17.
86
*/
97
public class SequentialForwardSelection extends FeatureSelection {
108

11-
public Set<Integer> select(final Set<Instance> instances, int numFeaturesToSelect) {
12-
// In this case we have no data to use, so return the empty set
13-
if(instances == null || instances.isEmpty()) return new HashSet<Integer>();
14-
15-
// Extract an instance to check the amount of features, assumes all instances have same # of features
16-
Instance sampleInstance = instances.iterator().next();
17-
int totalFeatures = sampleInstance.getNumFeatures();
18-
19-
// To begin with no features are selected, so all the indices from 0..totalFeatures are remaining
20-
Set<Integer> remainingFeatures = IntStream.rangeClosed(0, totalFeatures - 1)
21-
.boxed().collect(Collectors.toSet());
22-
23-
// Nothing we can do if the number of features to select is greater than or equal to the total size
24-
if (numFeaturesToSelect >= totalFeatures){
25-
return remainingFeatures;
26-
}
27-
28-
// Subset of only selected features indices
29-
Set<Integer> selectedFeatures = new HashSet<>();
30-
31-
Classifier classifier = new Classifier(instances);
32-
33-
while (selectedFeatures.size() < numFeaturesToSelect){
34-
int feature = best(classifier, selectedFeatures, remainingFeatures);
35-
36-
// No more valid features
37-
if (feature == -1) break;
9+
public SequentialForwardSelection(Set<Instance> instances){
10+
super(instances);
11+
}
3812

39-
selectedFeatures.add(feature);
40-
// Remove the feature so we do not keep selecting the same one
41-
remainingFeatures.remove(feature);
42-
}
13+
public Set<Integer> select(int numFeaturesToSelect) {
14+
return select((features, size) -> size < numFeaturesToSelect);
15+
}
4316

44-
compareAccuracy(classifier, selectedFeatures, totalFeatures);
45-
return selectedFeatures;
17+
public Set<Integer> select(double goalAccuracy) {
18+
return select((features, size) -> objectiveFunction( features) < goalAccuracy);
4619
}
4720

48-
public Set<Integer> select(final Set<Instance> instances, double goalAccuracy) {
21+
public Set<Integer> select(Criteria criteria) {
4922
// In this case we have no data to use, so return the empty set
5023
if (instances == null || instances.isEmpty()) return new HashSet<Integer>();
5124

52-
// Extract an instance to check the amount of features, assumes all instances have same # of features
53-
Instance sampleInstance = instances.iterator().next();
54-
int totalFeatures = sampleInstance.getNumFeatures();
55-
5625
// To begin with no features are selected, so all the indices from 0..totalFeatures are remaining
57-
Set<Integer> remainingFeatures = IntStream.rangeClosed(0, totalFeatures - 1)
58-
.boxed().collect(Collectors.toSet());
26+
Set<Integer> remainingFeatures = getFeatures();
5927

6028
// Subset of only selected features indices
6129
Set<Integer> selectedFeatures = new HashSet<>();
6230

63-
// Track classifiction accuracy
64-
double accuracy = 0;
65-
6631
Classifier classifier = new Classifier(instances);
6732

68-
while (accuracy < goalAccuracy){
69-
int feature = best(classifier, selectedFeatures, remainingFeatures);
33+
while (criteria.evaluate(selectedFeatures, selectedFeatures.size())){
34+
int feature = best(selectedFeatures, remainingFeatures);
7035
// No more valid features
7136
if (feature == -1) break;
7237

7338
selectedFeatures.add(feature);
7439
// Remove the feature so we do not keep selecting the same one
7540
remainingFeatures.remove(feature);
76-
77-
accuracy = objectiveFunction(classifier, selectedFeatures);
7841
}
7942

80-
compareAccuracy(classifier, selectedFeatures, totalFeatures);
8143
return selectedFeatures;
8244
}
8345

8446

85-
86-
8747
}

0 commit comments

Comments
 (0)