-
Notifications
You must be signed in to change notification settings - Fork 0
/
NeuralNetwork.java
139 lines (125 loc) · 4.13 KB
/
NeuralNetwork.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import java.util.*;
import java.io.*;
/**
 * A fully connected feed-forward network of {@code Perceptron} layers, trained by
 * per-sample gradient descent (backpropagation of squared error).
 *
 * <p>Layer 0 receives the raw input vector; every later layer receives the outputs of the
 * layer before it. The error-signal formula {@code o * (1 - o)} used in training assumes each
 * {@code Perceptron} applies a logistic sigmoid activation — NOTE(review): confirm against
 * {@code Perceptron.out}.
 *
 * <p>Field names and modifiers are intentionally left unchanged so the default
 * serialization UID of previously saved networks is not affected.
 */
public class NeuralNetwork implements Serializable {
    /** Layers in forward order; {@code network.get(0)} consumes the raw input. */
    private List<List<Perceptron>> network;
    /** Step size applied to every weight update. */
    private double learningRate;

    /**
     * Builds a network with the given layer sizes.
     *
     * @param inputs number of values in each input vector fed to layer 0
     * @param sizes  number of perceptrons in each layer, first layer first
     * @param lR     initial learning rate
     */
    public NeuralNetwork(int inputs, List<Integer> sizes, double lR) {
        learningRate = lR;
        network = new ArrayList<List<Perceptron>>(sizes.size());
        for (int layer = 0; layer < sizes.size(); layer++) {
            // Each perceptron takes one input per node of the previous layer (or per raw input).
            int fanIn = (layer == 0) ? inputs : sizes.get(layer - 1);
            List<Perceptron> col = new ArrayList<Perceptron>(sizes.get(layer));
            for (int j = 0; j < sizes.get(layer); j++) {
                col.add(new Perceptron(fanIn));
            }
            network.add(col);
        }
    }

    /** Sets the learning rate used by subsequent training steps. */
    public void setLR(double lR) {
        learningRate = lR;
    }

    /** Returns the current learning rate. */
    public double getLR() {
        return learningRate;
    }

    /**
     * Runs a forward pass WITHOUT recording per-perceptron outputs, so training state is left
     * untouched.
     *
     * @param init input vector for layer 0
     * @return outputs of the last layer ({@code init} itself if the network has no layers)
     */
    public List<Double> doOut(List<Double> init) {
        return forward(init, false);
    }

    /**
     * Runs a forward pass and records every perceptron's output via {@code setOutput}, which
     * {@link #trainStep} later reads.
     *
     * @param init input vector for layer 0
     * @return outputs of the last layer ({@code init} itself if the network has no layers)
     */
    public List<Double> out(List<Double> init) {
        return forward(init, true);
    }

    /** Shared forward pass; when {@code record} is true each perceptron's output is stored. */
    private List<Double> forward(List<Double> init, boolean record) {
        List<Double> activations = init;
        for (List<Perceptron> layer : network) {
            List<Double> next = new ArrayList<Double>(layer.size());
            for (Perceptron p : layer) {
                double value = p.out(activations);
                next.add(value);
                if (record) {
                    p.setOutput(value);
                }
            }
            activations = next;
        }
        return activations;
    }

    /** Signed error of one output node; runs (and records) a full forward pass. Currently unused. */
    private double error(int nodeId, List<Double> init, List<Double> expected) {
        return out(init).get(nodeId) - expected.get(nodeId);
    }

    /**
     * Computes half the sum of squared errors over the output layer.
     * Runs {@link #out}, so recorded outputs are overwritten with those for {@code init}.
     *
     * @param init     input vector
     * @param expected target values, at least as many as output nodes
     * @return {@code 0.5 * sum((out - expected)^2)}
     */
    public double totalError(List<Double> init, List<Double> expected) {
        double total = 0;
        List<Double> res = out(init);
        for (int i = 0; i < res.size(); i++) {
            total += Math.pow(res.get(i) - expected.get(i), 2);
        }
        return total / 2;
    }

    /**
     * Performs one backpropagation update for a single sample.
     *
     * <p>Precondition: {@link #out} must have been called with the same {@code init} so that
     * each perceptron's recorded output matches this sample. All error signals are computed
     * first, using the weights from that forward pass; only then are weights changed. This is
     * why single-layer networks are also handled correctly.
     *
     * @param init     the input vector that produced the recorded outputs
     * @param expected target values, at least as many as output nodes
     */
    public void trainStep(List<Double> init, List<Double> expected) {
        int last = network.size() - 1;
        if (last < 0) {
            return; // no layers, nothing to train
        }
        computeErrorSignals(last, expected);
        applyWeightUpdates(last, init);
    }

    /** Sets every perceptron's error signal, output layer first, without touching weights. */
    private void computeErrorSignals(int last, List<Double> expected) {
        List<Perceptron> outputLayer = network.get(last);
        for (int i = 0; i < outputLayer.size(); i++) {
            Perceptron p = outputLayer.get(i);
            double o = p.getOutput();
            p.setErrorSignal((o - expected.get(i)) * o * (1 - o));
        }
        for (int layer = last - 1; layer >= 0; layer--) {
            List<Perceptron> current = network.get(layer);
            List<Perceptron> above = network.get(layer + 1);
            for (int i = 0; i < current.size(); i++) {
                // Weighted sum of downstream error signals through the (not yet updated) weights.
                double downstream = 0;
                for (Perceptron q : above) {
                    downstream += q.getErrorSignal() * q.getWeight(i);
                }
                double o = current.get(i).getOutput();
                current.get(i).setErrorSignal(downstream * o * (1 - o));
            }
        }
    }

    /** Applies {@code w_j -= learningRate * errorSignal * input_j} to every perceptron. */
    private void applyWeightUpdates(int last, List<Double> init) {
        for (int layer = 0; layer <= last; layer++) {
            // Inputs seen by this layer: the raw sample, or the recorded outputs of the layer below.
            List<Double> layerInput = (layer == 0) ? init : recordedOutputs(network.get(layer - 1));
            for (Perceptron p : network.get(layer)) {
                double step = -learningRate * p.getErrorSignal();
                for (int j = 0; j < layerInput.size(); j++) {
                    p.changeWeight(j, step * layerInput.get(j));
                }
            }
        }
    }

    /** Collects the outputs recorded by the last {@link #out} call for one layer. */
    private static List<Double> recordedOutputs(List<Perceptron> layer) {
        List<Double> outs = new ArrayList<Double>(layer.size());
        for (Perceptron p : layer) {
            outs.add(p.getOutput());
        }
        return outs;
    }

    /**
     * Trains on every sample once, in order (forward pass, then {@link #trainStep}).
     *
     * @param inputs   input vectors
     * @param expected target vectors, one per input
     * @throws IllegalArgumentException if the two lists differ in length
     */
    public void train(List<List<Double>> inputs, List<List<Double>> expected) {
        if (inputs.size() != expected.size()) {
            throw new IllegalArgumentException(
                    "inputs.size()=" + inputs.size() + " != expected.size()=" + expected.size());
        }
        for (int i = 0; i < inputs.size(); i++) {
            out(inputs.get(i));
            trainStep(inputs.get(i), expected.get(i));
        }
    }

    /**
     * Derivative of the logistic sigmoid, {@code s(x) * (1 - s(x))}. This is positive; the
     * previous form had the wrong sign and produced NaN for large negative {@code x}.
     * Currently unused.
     */
    private double phiprime(double x) {
        double s = 1 / (1 + Math.exp(-x));
        return s * (1 - s);
    }

    /** Debug helper: prints the learning rate, then the size of each layer, one per line. */
    public void printTest() {
        System.out.println(learningRate);
        for (int i = 0; i < network.size(); i++) {
            System.out.println(network.get(i).size());
        }
    }
}