Skip to content

Commit 2f5c851

Browse files
committed
修改导函数
1 parent 8e6655e commit 2f5c851

File tree

13 files changed

+147
-47
lines changed

13 files changed

+147
-47
lines changed

README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# 简单的全连接神经网络实现
22

3-
优化器使用批量梯度下降,损失函数使用交叉熵函数
3+
优化器使用批量梯度下降,损失函数使用交叉熵函数,支持sigmoid与tanh两种激活函数
44

55
<img src="https://raw.githubusercontent.com/lrenc/FoolNet/master/foolnet.png" width="50%">
6+
7+
```javascript
8+
function fn(x1, x2) {
9+
let y = 0;
10+
if (x1 > 0.5 && x2 > 0.5) {
11+
y = 1;
12+
} else if (x1 < 0.5 && x2 < 0.5) {
13+
y = 1;
14+
}
15+
return y;
16+
}
17+
```
18+
19+
<img src="https://raw.githubusercontent.com/lrenc/FoolNet/master/result.png" width="50%">

index.html

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
<title>Foo Net</title>
66
</head>
77
<body>
8-
<div id="root"></div>
8+
<div id="root">
9+
<canvas id="canvas" width="300" height="300"></canvas>
10+
</div>
911
</body>
1012
</html>

result.png

197 KB
Loading

src/activation.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
12
// Activation functions. Each maps a pre-activation value x to the
// neuron's output.

// Logistic sigmoid: maps R -> (0, 1).
export function sigmoid(x) {
  return 1 / (1 + Math.exp(-x));
}

// Hyperbolic tangent: maps R -> (-1, 1).
// Uses Math.tanh rather than (e^x - e^-x) / (e^x + e^-x): the hand-rolled
// form evaluates to Infinity/Infinity === NaN once Math.exp(x) overflows
// (|x| greater than ~710), whereas Math.tanh saturates cleanly to +/-1.
export function tanh(x) {
  return Math.tanh(x);
}

src/data.js

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11

2+
// Label rule for a point in the unit square: 1 when both coordinates lie
// strictly on the same side of 0.5 (both above or both below), otherwise 0
// (a coordinate exactly equal to 0.5 yields 0).
function fn(x1, x2) {
  if (x1 > 0.5 && x2 > 0.5) return 1;
  if (x1 < 0.5 && x2 < 0.5) return 1;
  return 0;
}

// Builds 100 random training samples drawn uniformly from the unit square.
// Returns [inputs, labels] where inputs is [[x1, x2], ...] and labels is
// the matching [[y], ...] produced by fn.
export function getTrainData() {
  const inputs = [];
  const labels = [];
  for (let i = 0; i < 100; i += 1) {
    const x1 = Math.random();
    const x2 = Math.random();
    inputs.push([x1, x2]);
    labels.push([fn(x1, x2)]);
  }
  return [inputs, labels];
}
24+
25+
export function getTestData() {
26+
const inputs = [];
27+
const labels = [];
28+
for (let i = 0; i < 8000; i ++) {
629
let x1 = Math.random();
730
let x2 = Math.random();
8-
let y = 0;
9-
if (x1 <= 0.5 && x2 <= 0.5) {
10-
y = 1;
11-
}
31+
let y = fn(x1, x2);
1232
inputs.push([x1, x2]);
1333
labels.push([y]);
1434
}

src/derivative.js

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Derivatives of the activation functions, expressed in terms of the
// activation OUTPUT a (not the pre-activation input), which is what the
// backward pass has on hand.

// d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) = a * (1 - a).
export function sigmoid(a) {
  return a * (1 - a);
}

// d/dx tanh(x) = 1 - tanh(x)^2 = 1 - a^2.
export function tanh(a) {
  return 1 - a * a;
}

src/index.css

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
/* Display the 300x300 backing canvas at 150px square (2x density,
   so the plotted points render crisply on high-DPI screens). */
canvas {
  width: 150px;
  height: 150px;
}

src/index.js

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
import Layer from './layer';
22
import Model from './Model';
3-
import { getTrainData } from './data';
4-
import './test';
3+
import { getTrainData, getTestData } from './data';
4+
import './index.css';
55

66
const l0 = new Layer({
7-
type: 'inputLayer',
87
units: 2
98
});
109

1110
const l1 = new Layer({
12-
activation: 'sigmoid',
11+
activation: 'tanh',
1312
units: 3
1413
});
1514

1615
const l2 = new Layer({
17-
activation: 'sigmoid',
16+
activation: 'tanh',
1817
units: 3
1918
});
2019

@@ -26,16 +25,42 @@ const l3 = new Layer({
2625
const model = new Model([l0, l1, l2, l3]);
2726

2827
let step = 2;
29-
for (let i = 0; i < 5000; i ++) {
28+
for (let i = 0; i < 8000; i ++) {
3029
if (i % 1000 === 0) {
3130
step /= 2;
3231
}
3332
model.fit(...getTrainData(), { step });
3433
}
35-
console.log(model.predict([0.1, 0.1])); // 1
36-
console.log(model.predict([0.2, 0.2])); // 1
37-
console.log(model.predict([0.3, 0.3])); // 1
38-
console.log(model.predict([0.9, 0.1])); // 0
39-
console.log(model.predict([0.1, 0.9])); // 0
40-
console.log(model.predict([0.6, 0.6])); // 0
41-
console.log(model.predict([0.9, 0.9])); // 0
34+
35+
const canvas = document.getElementById('canvas');
36+
const width = canvas.width;
37+
const height = canvas.height;
38+
const ctx = canvas.getContext('2d');
39+
40+
// Renders a filled dot of radius 3 at canvas coordinates (x, y) using the
// shared 2d context `ctx` declared at module level.
function drawPoint(x, y, color) {
  ctx.fillStyle = color;
  ctx.beginPath();
  ctx.arc(x, y, 3, 0, 2 * Math.PI, true);
  ctx.fill();
}
46+
47+
// Visualizes the trained model's decision regions: classify every random
// test point and plot it on the canvas — red when the predicted value is
// <= 0.5, blue otherwise. (The ground-truth labels from getTestData are
// not needed here; only the inputs are consumed.)
function test() {
  const [inputs] = getTestData();
  for (const input of inputs) {
    const output = model.predict(input)[0];
    // input coordinates are in [0, 1); scale to canvas pixels.
    const px = input[0] * width;
    const py = input[1] * height;
    const color = output <= 0.5 ? 'rgb(255, 0, 0)' : 'rgb(0, 0, 255)';
    drawPoint(px, py, color);
  }
}

test();

src/layer.js

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,17 @@
11
import Neure from './neure';
2-
import * as fn from './activation';
3-
4-
const INPUT_LAYER = 'inputLayer';
52

63
export default class Layer {
7-
84
constructor(params) {
95
const { units, activation, type } = params;
106
this.units = units;
117
this.neures = Array(units).fill(null);
12-
this.activation = null;
13-
if (fn[activation]) {
14-
this.activation = fn[activation];
15-
}
16-
if (type === INPUT_LAYER) {
17-
this.neures = this.neures.map(() => {
18-
return new Neure(1, null, true);
19-
});
20-
}
8+
this.activation = activation;
9+
}
10+
11+
init() {
12+
this.neures = this.neures.map(() => {
13+
return new Neure(1, null, true);
14+
});
2115
}
2216

2317
apply(layer) {

src/model.js

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11

2-
import { dot, sub, sum, power, last, transpose, div, mul } from './util';
2+
import Layer from './layer';
3+
import * as derivate from './derivative';
4+
import { dot, sub, sum, power, transpose, div, mul, add } from './util';
35

46
export default class Network {
57
constructor(layers) {
8+
const layer = layers[0];
9+
layer.init();
610
for (let i = 0; i < layers.length - 1; i ++) {
711
const curr = layers[i];
812
const next = layers[i + 1];
@@ -43,17 +47,28 @@ export default class Network {
4347
const size = inputs.length;
4448
let layers = this.layers.slice();
4549
let p = this.multiPredict(inputs);
46-
let dz = sub(p, labels); // 二维数组
4750
let lastLayer = layers.pop();
51+
let activation = lastLayer.activation;
52+
let dz;
53+
if (activation === 'sigmoid') {
54+
dz = sub(p, labels); // 二维数组
55+
} else {
56+
dz = div(mul(sub(p, labels), add(1, p)), p);
57+
}
4858
p = this.multiPredict(inputs, layers);
4959
let dw = div(dot(transpose(dz), p), size); // A1
5060
let db = sum(dz) / size;
5161
dws.unshift(dw);
5262
dbs.unshift(db);
5363
while (true) {
5464
const w = lastLayer.neures.map(neure => neure.weights);
55-
dz = mul(dot(dz, w), sub(1, power(p, 2)));
5665
lastLayer = layers.pop();
66+
activation = lastLayer.activation;
67+
if (activation === 'sigmoid') {
68+
dz = mul(dot(dz, w), sub(p, power(p, 2)));
69+
} else {
70+
dz = mul(dot(dz, w), sub(1, power(p, 2)));
71+
}
5772
if (layers.length === 0) {
5873
break;
5974
}

0 commit comments

Comments
 (0)