Skip to content

Commit 25f7022

Browse files
Ben EvansBen Evans
Ben Evans
authored and
Ben Evans
committed
Decoupled classifier and FeatureSelection, by moving the classifier into FeatyureSelection instead of constantly being passed around
1 parent 89e89fd commit 25f7022

6 files changed

+91
-91
lines changed

Classifier.java

+18-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import java.io.File;
2-
import java.io.FileNotFoundException;
31
import java.util.Comparator;
42
import java.util.HashMap;
53
import java.util.HashSet;
64
import java.util.PriorityQueue;
7-
import java.util.Scanner;
85
import java.util.Set;
6+
import java.util.stream.Collectors;
7+
import java.util.stream.IntStream;
98

109

1110
/**
@@ -30,6 +29,21 @@ public Classifier(Set<Instance> instances) {
3029
}
3130
}
3231

32+
/**
33+
* Classifies and calculates the percentage
34+
* of correct classifications in the testingSet
35+
* against the training set.
36+
*/
37+
public double classify(){
38+
Instance sampleInstance = training.iterator().next();
39+
int totalFeatures = sampleInstance.getNumFeatures();
40+
41+
// To begin with all features are selected
42+
Set<Integer> allIndices = IntStream.rangeClosed(0, totalFeatures - 1)
43+
.boxed().collect(Collectors.toSet());
44+
45+
return classify(allIndices);
46+
}
3347
/**
3448
* Classifies and calculates the percentage
3549
* of correct classifications in the testingSet
@@ -70,6 +84,7 @@ public int compare(Result a, Result b) {
7084
return correct/(double)testing.size();
7185
}
7286

87+
7388
/**
7489
* Returns the mode of @param list
7590
* @param list

Criteria.java

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import java.util.Set;
2+
3+
public interface Criteria {
4+
public boolean evaluate(Set<Integer> features, int size);
5+
}

FeatureSelection.java

+27-12
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,26 @@
88
*/
99
public abstract class FeatureSelection {
1010

11+
private Classifier classifier;
12+
protected Set<Instance> instances;
13+
14+
public FeatureSelection(Set<Instance> instances){
15+
this.instances = instances;
16+
this.classifier = new Classifier(instances);
17+
}
1118

1219
/**
1320
Returns a subset of only the most important features
1421
chosen by some measure.
1522
*/
16-
public abstract Set<Integer> select(Set<Instance> instances, int numFeaturesToSelect);
23+
public abstract Set<Integer> select(int numFeaturesToSelect);
1724

1825
/**
1926
Returns a subset containing only the numFeatures most important
2027
features. If numFeatures is >= original.size(), the original
2128
set is returned.
2229
*/
23-
public abstract Set<Integer> select(Set<Instance> instances, double goalAccuracy);
30+
public abstract Set<Integer> select(double goalAccuracy);
2431

2532
/**
2633
* Returns the feature in remaining features
@@ -30,15 +37,15 @@ public abstract class FeatureSelection {
3037
* @param remainingFeatures
3138
* @return
3239
*/
33-
protected int best(Classifier classifier, Set<Integer> selectedFeatures, Set<Integer> remainingFeatures){
40+
protected int best(Set<Integer> selectedFeatures, Set<Integer> remainingFeatures){
3441
double highest = -Integer.MAX_VALUE;
3542
int selected = -1;
3643

3744
for(int feature: remainingFeatures){
3845
Set<Integer> newFeatures = new HashSet<>(selectedFeatures);
3946
newFeatures.add(feature);
4047

41-
double result = objectiveFunction(classifier, newFeatures);
48+
double result = objectiveFunction(newFeatures);
4249
if(result > highest){
4350
highest = result;
4451
selected = feature;
@@ -55,15 +62,15 @@ protected int best(Classifier classifier, Set<Integer> selectedFeatures, Set<Int
5562
* @param features
5663
* @return
5764
*/
58-
protected int worst(Classifier classifier, Set<Integer> features){
65+
protected int worst(Set<Integer> features){
5966
double lowestAccuracy = Integer.MAX_VALUE;
6067
int selected = -1;
6168

6269
for(int feature: features){
6370
Set<Integer> newFeatures = new HashSet<>(features);
6471
newFeatures.remove(feature);
6572

66-
double result = objectiveFunction(classifier, newFeatures);
73+
double result = objectiveFunction(newFeatures);
6774
if(result < lowestAccuracy){
6875
lowestAccuracy = result;
6976
selected = feature;
@@ -73,15 +80,23 @@ protected int worst(Classifier classifier, Set<Integer> features){
7380
return selected;
7481
}
7582

76-
public void compareAccuracy(Classifier classifier, Set<Integer> selectedIndices, int totalFeatures) {
77-
Set<Integer> allIndices = IntStream.rangeClosed(0, totalFeatures - 1)
83+
protected double objectiveFunction(Set<Integer> selectedFeatures) {
84+
return classifier.classify(selectedFeatures);
85+
}
86+
87+
protected Set<Integer> getFeatures(){
88+
// Extract an instance to check the amount of features, assumes all instances have same # of features
89+
Instance sampleInstance = instances.iterator().next();
90+
int totalFeatures = sampleInstance.getNumFeatures();
91+
92+
// To begin with all features are selected
93+
return IntStream.rangeClosed(0, totalFeatures - 1)
7894
.boxed().collect(Collectors.toSet());
79-
System.out.println("Classification accuracy on testing set using all features: " + classifier.classify(allIndices));
80-
System.out.println("Classification accuracy on testing set using features " + selectedIndices + ": " + classifier.classify(selectedIndices));
8195
}
8296

83-
protected double objectiveFunction(Classifier classifier, Set<Integer> selectedFeatures) {
84-
return classifier.classify(selectedFeatures);
97+
public void compareAccuracy(Set<Integer> selectedIndices) {
98+
System.out.println("Classification accuracy on testing set using all features: " + classifier.classify());
99+
System.out.println("Classification accuracy on testing set using features " + selectedIndices + ": " + classifier.classify(selectedIndices));
85100
}
86101

87102

Instance.java

+11
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ public Instance(double[] features, String label){
2727

2828
public double getFeature(int index){ return features[index]; }
2929

30+
public Instance createInstance(Set<Integer> featureIndices){
31+
double[] features = new double[featureIndices.size()];
32+
33+
int i=0;
34+
for(int index: featureIndices){
35+
features[i++] = features[index];
36+
}
37+
38+
return new Instance(features , label);
39+
}
40+
3041
public double distanceTo(Instance other, Set<Integer> indices){
3142
if(getNumFeatures() != other.getNumFeatures()) throw new IllegalArgumentException("Number of features do not match");
3243

SequentialBackwardsSelection.java

+18-24
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,34 @@
11
import java.util.HashSet;
22
import java.util.Set;
3-
import java.util.stream.Collectors;
4-
import java.util.stream.IntStream;
53

64
/**
75
* Created by ben on 18/04/17.
86
*/
97
public class SequentialBackwardsSelection extends FeatureSelection {
108

11-
@Override
12-
public Set<Integer> select(Set<Instance> instances, int numFeaturesToSelect) {
13-
// In this case we have no data to use, so return the empty set
14-
if(instances == null || instances.isEmpty()) return new HashSet<Integer>();
9+
public SequentialBackwardsSelection(Set<Instance> instances){
10+
super(instances);
11+
}
1512

16-
// Extract an instance to check the amount of features, assumes all instances have same # of features
17-
Instance sampleInstance = instances.iterator().next();
18-
int totalFeatures = sampleInstance.getNumFeatures();
13+
@Override
14+
public Set<Integer> select(int numFeaturesToSelect) {
15+
return select((features, size) -> size > numFeaturesToSelect);
16+
}
1917

20-
// To begin with all features are selected
21-
Set<Integer> selectedFeatures = IntStream.rangeClosed(0, totalFeatures - 1)
22-
.boxed().collect(Collectors.toSet());
18+
@Override
19+
public Set<Integer> select(double goalAccuracy) {
20+
return select((features, size) -> objectiveFunction(features) < goalAccuracy);
21+
}
2322

24-
// Nothing we can do if the number of features to select is greater than or equal to the total size
25-
if (numFeaturesToSelect >= totalFeatures){
26-
return selectedFeatures;
27-
}
23+
public Set<Integer> select(Criteria criteria) {
24+
// In this case we have no data to use, so return the empty set
25+
if (instances == null || instances.isEmpty()) return new HashSet<Integer>();
2826

29-
Classifier classifier = new Classifier(instances);
27+
// To begin with all features are selected
28+
Set<Integer> selectedFeatures = getFeatures();
3029

31-
while (selectedFeatures.size() >= numFeaturesToSelect){
32-
int feature = worst(classifier, selectedFeatures);
30+
while (criteria.evaluate(selectedFeatures, selectedFeatures.size())){
31+
int feature = worst(selectedFeatures);
3332

3433
// No more valid features
3534
if (feature == -1) break;
@@ -40,9 +39,4 @@ public Set<Integer> select(Set<Instance> instances, int numFeaturesToSelect) {
4039

4140
return selectedFeatures;
4241
}
43-
44-
@Override
45-
public Set<Integer> select(Set<Instance> instances, double goalAccuracy) {
46-
return null;
47-
}
4842
}

SequentialForwardSelection.java

+12-52
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,47 @@
11
import java.util.HashSet;
22
import java.util.Set;
3-
import java.util.stream.Collectors;
4-
import java.util.stream.IntStream;
53

64
/**
75
* Created by ben on 8/04/17.
86
*/
97
public class SequentialForwardSelection extends FeatureSelection {
108

11-
public Set<Integer> select(final Set<Instance> instances, int numFeaturesToSelect) {
12-
// In this case we have no data to use, so return the empty set
13-
if(instances == null || instances.isEmpty()) return new HashSet<Integer>();
14-
15-
// Extract an instance to check the amount of features, assumes all instances have same # of features
16-
Instance sampleInstance = instances.iterator().next();
17-
int totalFeatures = sampleInstance.getNumFeatures();
18-
19-
// To begin with no features are selected, so all the indices from 0..totalFeatures are remaining
20-
Set<Integer> remainingFeatures = IntStream.rangeClosed(0, totalFeatures - 1)
21-
.boxed().collect(Collectors.toSet());
22-
23-
// Nothing we can do if the number of features to select is greater than or equal to the total size
24-
if (numFeaturesToSelect >= totalFeatures){
25-
return remainingFeatures;
26-
}
27-
28-
// Subset of only selected features indices
29-
Set<Integer> selectedFeatures = new HashSet<>();
30-
31-
Classifier classifier = new Classifier(instances);
32-
33-
while (selectedFeatures.size() < numFeaturesToSelect){
34-
int feature = best(classifier, selectedFeatures, remainingFeatures);
35-
36-
// No more valid features
37-
if (feature == -1) break;
9+
public SequentialForwardSelection(Set<Instance> instances){
10+
super(instances);
11+
}
3812

39-
selectedFeatures.add(feature);
40-
// Remove the feature so we do not keep selecting the same one
41-
remainingFeatures.remove(feature);
42-
}
13+
public Set<Integer> select(int numFeaturesToSelect) {
14+
return select((features, size) -> size < numFeaturesToSelect);
15+
}
4316

44-
compareAccuracy(classifier, selectedFeatures, totalFeatures);
45-
return selectedFeatures;
17+
public Set<Integer> select(double goalAccuracy) {
18+
return select((features, size) -> objectiveFunction( features) < goalAccuracy);
4619
}
4720

48-
public Set<Integer> select(final Set<Instance> instances, double goalAccuracy) {
21+
public Set<Integer> select(Criteria criteria) {
4922
// In this case we have no data to use, so return the empty set
5023
if (instances == null || instances.isEmpty()) return new HashSet<Integer>();
5124

52-
// Extract an instance to check the amount of features, assumes all instances have same # of features
53-
Instance sampleInstance = instances.iterator().next();
54-
int totalFeatures = sampleInstance.getNumFeatures();
55-
5625
// To begin with no features are selected, so all the indices from 0..totalFeatures are remaining
57-
Set<Integer> remainingFeatures = IntStream.rangeClosed(0, totalFeatures - 1)
58-
.boxed().collect(Collectors.toSet());
26+
Set<Integer> remainingFeatures = getFeatures();
5927

6028
// Subset of only selected features indices
6129
Set<Integer> selectedFeatures = new HashSet<>();
6230

63-
// Track classifiction accuracy
64-
double accuracy = 0;
65-
6631
Classifier classifier = new Classifier(instances);
6732

68-
while (accuracy < goalAccuracy){
69-
int feature = best(classifier, selectedFeatures, remainingFeatures);
33+
while (criteria.evaluate(selectedFeatures, selectedFeatures.size())){
34+
int feature = best(selectedFeatures, remainingFeatures);
7035
// No more valid features
7136
if (feature == -1) break;
7237

7338
selectedFeatures.add(feature);
7439
// Remove the feature so we do not keep selecting the same one
7540
remainingFeatures.remove(feature);
76-
77-
accuracy = objectiveFunction(classifier, selectedFeatures);
7841
}
7942

80-
compareAccuracy(classifier, selectedFeatures, totalFeatures);
8143
return selectedFeatures;
8244
}
8345

8446

85-
86-
8747
}

0 commit comments

Comments
 (0)