-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindex.js
66 lines (56 loc) · 1.28 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import Layer from './layer';
import Model from './model';
import { getTrainData, getTestData } from './data';
import './index.css';
// Layer configurations, in the order they are handed to Model. The first
// layer declares no activation — presumably it is the input layer
// (TODO: confirm against ./layer).
const layerConfigs = [
  { units: 2 },
  { activation: 'tanh', units: 3 },
  { activation: 'tanh', units: 3 },
  { activation: 'sigmoid', units: 1 },
];
const model = new Model(layerConfigs.map((config) => new Layer(config)));

// Training schedule: TRAIN_ITERATIONS calls to `fit`, each on a fresh
// getTrainData() result. The step size starts at 1 and halves every
// DECAY_INTERVAL iterations (1, 0.5, 0.25, ..., 1/128). Powers of two are
// exact in floating point, so this equals repeatedly halving a running value.
const TRAIN_ITERATIONS = 8000;
const DECAY_INTERVAL = 1000;
for (let iteration = 0; iteration < TRAIN_ITERATIONS; iteration += 1) {
  const step = 2 ** -Math.floor(iteration / DECAY_INTERVAL);
  model.fit(...getTrainData(), { step });
}
// Drawing surface used by drawPoint/test. Fail fast with a descriptive error
// when the page has no element with id "canvas" or the 2D context cannot be
// obtained; otherwise the failure surfaces later as an opaque TypeError
// (reading `width` of null, or calling `beginPath` on null).
const canvas = document.getElementById('canvas');
if (canvas === null) {
  throw new Error('index.js: expected an element with id "canvas" in the page');
}
const { width, height } = canvas;
const ctx = canvas.getContext('2d');
if (ctx === null) {
  throw new Error('index.js: could not obtain a 2D rendering context from #canvas');
}
/**
 * Draw a filled circle on the module-level 2D context `ctx`.
 *
 * @param {number} x - Centre x coordinate, in canvas coordinates.
 * @param {number} y - Centre y coordinate, in canvas coordinates.
 * @param {string} color - CSS colour string assigned to `ctx.fillStyle`.
 * @param {number} [radius=3] - Circle radius; defaults to the original fixed size.
 */
function drawPoint(x, y, color, radius = 3) {
  ctx.beginPath();
  // Sweep the full 0 → 2π arc; the direction flag is irrelevant for a full circle.
  ctx.arc(x, y, radius, 0, Math.PI * 2, true);
  ctx.fillStyle = color;
  ctx.fill();
}
/**
 * Classify every test sample with the trained `model` and plot it on the canvas.
 *
 * A sample is drawn red when the model's first output is <= 0.5 and blue
 * otherwise. Its position is its first two features scaled by the canvas
 * width and height, so the features are presumably normalized to [0, 1]
 * (TODO: confirm against ./data).
 */
function test() {
  // Only the inputs are needed to plot predictions; the labels are not read.
  const [inputs] = getTestData();
  for (const input of inputs) {
    const output = model.predict(input)[0];
    const x = input[0] * width;
    const y = input[1] * height;
    const color = output <= 0.5 ? 'rgb(255, 0, 0)' : 'rgb(0, 0, 255)';
    drawPoint(x, y, color);
  }
}
// Entry point: plot the trained model's predictions for the test set on the canvas.
test();