-
Notifications
You must be signed in to change notification settings - Fork 0
/
Pso.java
104 lines (86 loc) · 2.65 KB
/
Pso.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import java.util.Arrays;
import java.util.ArrayList;
/**
 * One particle of a particle swarm optimizer whose positions are {@code int[]} vectors.
 *
 * <p>The PSO coefficients, the problem id and the swarm-wide best solution live in static fields,
 * so they are shared by every particle. {@code dataset_id} must be set before a particle is
 * constructed, and {@code global_best_x} before {@link #update()} is called.
 * NOTE(review): this shared static state means two swarms cannot run concurrently.
 */
class Particle {
    // Swarm-wide PSO coefficients, shared by all particles.
    /** Inertia weight applied to the previous velocity. */
    public static double W;
    /** Cognitive coefficient: strength of the pull toward this particle's own best position. */
    public static double C1;
    /** Social coefficient: strength of the pull toward the swarm's global best position. */
    public static double C2;
    /** Problem instance forwarded to {@code TestFunctions.get_fit}. */
    public static int dataset_id;

    // Best solution found by any particle so far; higher fitness is better (see update_best).
    public static int[] global_best_x;
    public static double global_best_fit;

    private int dim;
    private int[] x;
    private double[] v;
    private double fit;
    private int[] best_x;
    private double best_fit;

    /**
     * Creates a particle at the position produced by {@code TestFunctions.init_x}, with zero
     * velocity, and records that starting point as its personal best.
     *
     * @param dim number of dimensions of the search space
     */
    Particle(int dim) {
        this.dim = dim;
        x = TestFunctions.init_x(dim);
        v = new double[dim];
        fit = TestFunctions.get_fit(dataset_id, x);
        best_x = Arrays.copyOf(x, x.length);
        best_fit = fit;
    }

    /** Returns the current position (the internal array, not a copy). */
    int[] get_x() {
        return x;
    }

    /** Returns the fitness of the current position. */
    double get_fit() {
        return fit;
    }

    /**
     * Advances the particle by one PSO step.
     *
     * <p>First computes the new velocity from inertia, the pull toward the personal best and the
     * pull toward the global best. Then derives the new position from that updated velocity via
     * {@code TestFunctions.update_x}, and re-evaluates fitness. It does not touch the personal or
     * global bests; call {@link #update_best()} afterwards.
     */
    void update() {
        // One random factor per attraction term for the whole particle (not one per dimension).
        double r1 = Math.random();
        double r2 = Math.random();

        double[] new_v = new double[dim];
        for (int i = 0; i < dim; i++) {
            new_v[i] = W * v[i]
                    + C1 * r1 * (best_x[i] - x[i])
                    + C2 * r2 * (global_best_x[i] - x[i]);
        }
        v = new_v;

        // Standard PSO order: the position follows the freshly computed velocity, so the
        // personal/global attraction computed above influences this very step.
        x = TestFunctions.update_x(v);
        fit = TestFunctions.get_fit(dataset_id, x);
    }

    /**
     * Records the current position as the personal best and/or the swarm's global best when its
     * fitness is strictly greater than the stored best. Positions are copied, so later moves do not
     * alias the stored bests.
     */
    void update_best() {
        if (fit > best_fit) {
            best_fit = fit;
            best_x = Arrays.copyOf(x, x.length);
        }
        if (fit > global_best_fit) {
            global_best_fit = fit;
            global_best_x = Arrays.copyOf(x, x.length);
        }
    }
}
public class Pso{
public static int[] execute(int N, int T, int dataset_id, double[] params){
Particle.dataset_id = dataset_id;
Particle.W =params[0];
Particle.C1=params[1];
Particle.C2=params[2];
Particle.global_best_x = new int[N];
Particle.global_best_fit = 0;
int D = TestFunctions.get_weight(dataset_id).length;
ArrayList<Particle> particle_list = new ArrayList<Particle>();
for(int i=0;i<N;i++){
Particle part = new Particle(D);
particle_list.add(part);
particle_list.get(i).update_best();
}
for(int t=1;t<T+1;t++){
for(int i=0;i<N;i++){
particle_list.get(i).update();
particle_list.get(i).update_best();
}
}
return Particle.global_best_x;
}
public static void main(String[] args) {
// number of particles
int N = 10;
// number of iteration
int T = 10;
// parameters
double[] params ={0.3,0.6,0.6};
// dataset
int dataset_id = 4;
int[] result = execute(N,T,dataset_id,params);
}
}