Skip to content

jesper-olsen/ff-py

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

45 Commits
 
 
 
 
 
 
 
 

Repository files navigation

Forward-Forward

The Forward-Forward algorithm [1,2] evaluated on the MNIST handwritten digit recognition task; Python implementation following [3]. Two versions are implemented: one with numpy and one with jax.

Network: 5-layers; input layer with 784 (28x28) states, 3 hidden layers with 1000 states each and an output layer with 10 states corresponding to the 10 digits.

Data partitioning: Train on 50k samples, validate on 10k, test on 10k.

Error rate (energy / softmax):

  • Hinton's matlab code [2]: 1.44% / 1.47%
  • This repo (numpy): 1.33% / 1.38%
  • This repo (jax) :1.30% / 1.45%

See the Deep Boltzmann Machines repo for another result on the same task.

References:

[1] The Forward-Forward Algorithm: Some Preliminary Investigations, Geoffrey Hinton, NeurIPS 2022
[2] Hinton's NIPS'22 Talk
[3] Hinton's matlab code
[4] Hinton's preprocessed MNIST db
[5] LeCun's raw MNIST db

Run:

Download MNIST - either [5] or [4]; Edit mnist.py if [4]:

% mkdir -p MNIST/raw
% cd MNIST
% wget https://www.cs.toronto.edu/~hinton/mnistdata.mat
% cd raw
% wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
% wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
% wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
% wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
% gunzip *.gz
% cd ../..

Install dependencies

% uv venv
% uv pip install numpy jax matplotlib

Train a model - run either main.py (numpy version) or main_jax.py.

% time uv run main.py
Batchsize: 100 Input-dim: 784 #training batches: 500
states per layer:  [784, 1000, 1000, 1000, 10]
ep:   0 gain: 1.000 trainlogcost: 61.8054 PairwiseErrs: 5921, 5906, 6414
ep:   1 gain: 1.000 trainlogcost: 28.5926 PairwiseErrs: 2760, 2562, 3282
ep:   2 gain: 1.000 trainlogcost: 20.6733 PairwiseErrs: 2203, 1879, 2324
ep:   3 gain: 1.000 trainlogcost: 16.5810 PairwiseErrs: 1866, 1595, 1991
ep:   4 gain: 1.000 trainlogcost: 13.0995 PairwiseErrs: 1682, 1321, 1681
Energy-based errs: Train 2127/50000 Valid 438/10000
Softmax-based errs: Valid 370/10000
rms:  0.0403, 0.0361, 0.0409
suprms:  0.0403, 0.0218
ep:   5 gain: 1.000 trainlogcost: 11.1300 PairwiseErrs: 1497, 1138, 1413
ep:   6 gain: 1.000 trainlogcost: 9.4637 PairwiseErrs: 1321, 971, 1301
ep:   7 gain: 1.000 trainlogcost: 8.7353 PairwiseErrs: 1206, 902, 1187
ep:   8 gain: 1.000 trainlogcost: 7.9537 PairwiseErrs: 1179, 784, 1096
ep:   9 gain: 1.000 trainlogcost: 6.9619 PairwiseErrs: 1143, 750, 957
Energy-based errs: Train 1268/50000 Valid 343/10000
Softmax-based errs: Valid 285/10000
rms:  0.0418, 0.0363, 0.0408
suprms:  0.0321, 0.0195
ep:  10 gain: 1.000 trainlogcost: 6.0931 PairwiseErrs: 1044, 633, 882
ep:  11 gain: 1.000 trainlogcost: 5.3677 PairwiseErrs: 960, 624, 827
ep:  12 gain: 1.000 trainlogcost: 4.9876 PairwiseErrs: 894, 592, 836
ep:  13 gain: 1.000 trainlogcost: 4.7329 PairwiseErrs: 862, 525, 752
ep:  14 gain: 1.000 trainlogcost: 4.3135 PairwiseErrs: 821, 492, 694
Energy-based errs: Train 804/50000 Valid 270/10000
Softmax-based errs: Valid 249/10000
rms:  0.0424, 0.0363, 0.0404
suprms:  0.0271, 0.0197
ep:  15 gain: 1.000 trainlogcost: 3.8409 PairwiseErrs: 789, 447, 638
ep:  16 gain: 1.000 trainlogcost: 3.4593 PairwiseErrs: 755, 401, 607
ep:  17 gain: 1.000 trainlogcost: 3.4946 PairwiseErrs: 706, 407, 562
ep:  18 gain: 1.000 trainlogcost: 3.3885 PairwiseErrs: 679, 365, 572
ep:  19 gain: 1.000 trainlogcost: 2.9400 PairwiseErrs: 623, 349, 512
Energy-based errs: Train 507/50000 Valid 229/10000
Softmax-based errs: Valid 215/10000
rms:  0.0429, 0.0358, 0.0398
suprms:  0.0238, 0.0196
ep:  20 gain: 1.000 trainlogcost: 2.4612 PairwiseErrs: 586, 286, 430
ep:  21 gain: 1.000 trainlogcost: 2.6846 PairwiseErrs: 566, 310, 416
ep:  22 gain: 1.000 trainlogcost: 2.4294 PairwiseErrs: 570, 265, 441
ep:  23 gain: 1.000 trainlogcost: 2.0932 PairwiseErrs: 524, 247, 404
ep:  24 gain: 1.000 trainlogcost: 2.0991 PairwiseErrs: 539, 230, 372
Energy-based errs: Train 561/50000 Valid 228/10000
Softmax-based errs: Valid 237/10000
rms:  0.0432, 0.0352, 0.0391
suprms:  0.0217, 0.0193
ep:  25 gain: 1.000 trainlogcost: 1.9902 PairwiseErrs: 508, 229, 392
ep:  26 gain: 1.000 trainlogcost: 1.8399 PairwiseErrs: 473, 188, 338
ep:  27 gain: 1.000 trainlogcost: 1.8218 PairwiseErrs: 497, 197, 327
ep:  28 gain: 1.000 trainlogcost: 1.6203 PairwiseErrs: 439, 184, 292
ep:  29 gain: 1.000 trainlogcost: 1.6777 PairwiseErrs: 439, 168, 332
Energy-based errs: Train 300/50000 Valid 178/10000
Softmax-based errs: Valid 194/10000
rms:  0.0433, 0.0347, 0.0384
suprms:  0.0202, 0.0191
ep:  30 gain: 1.000 trainlogcost: 1.5999 PairwiseErrs: 438, 164, 253
ep:  31 gain: 1.000 trainlogcost: 1.6148 PairwiseErrs: 434, 170, 278
ep:  32 gain: 1.000 trainlogcost: 1.3281 PairwiseErrs: 417, 147, 280
ep:  33 gain: 1.000 trainlogcost: 1.2226 PairwiseErrs: 373, 141, 285
ep:  34 gain: 1.000 trainlogcost: 1.3062 PairwiseErrs: 397, 158, 251
Energy-based errs: Train 301/50000 Valid 205/10000
Softmax-based errs: Valid 209/10000
rms:  0.0433, 0.0341, 0.0376
suprms:  0.0190, 0.0186
ep:  35 gain: 1.000 trainlogcost: 1.3162 PairwiseErrs: 362, 131, 250
ep:  36 gain: 1.000 trainlogcost: 1.1939 PairwiseErrs: 398, 132, 265
ep:  37 gain: 1.000 trainlogcost: 1.2284 PairwiseErrs: 357, 106, 221
ep:  38 gain: 1.000 trainlogcost: 1.2620 PairwiseErrs: 401, 105, 207
ep:  39 gain: 1.000 trainlogcost: 1.0853 PairwiseErrs: 353, 109, 199
Energy-based errs: Train 227/50000 Valid 196/10000
Softmax-based errs: Valid 177/10000
rms:  0.0435, 0.0333, 0.0368
suprms:  0.0181, 0.0184
ep:  40 gain: 1.000 trainlogcost: 0.9844 PairwiseErrs: 348, 106, 206
ep:  41 gain: 1.000 trainlogcost: 1.0729 PairwiseErrs: 332, 110, 205
ep:  42 gain: 1.000 trainlogcost: 1.1798 PairwiseErrs: 366, 120, 191
ep:  43 gain: 1.000 trainlogcost: 1.1198 PairwiseErrs: 370, 109, 200
ep:  44 gain: 1.000 trainlogcost: 1.0763 PairwiseErrs: 348, 104, 205
Energy-based errs: Train 234/50000 Valid 182/10000
Softmax-based errs: Valid 184/10000
rms:  0.0440, 0.0326, 0.0361
suprms:  0.0175, 0.0181
ep:  45 gain: 1.000 trainlogcost: 0.9911 PairwiseErrs: 308, 82, 165
ep:  46 gain: 1.000 trainlogcost: 1.0948 PairwiseErrs: 335, 76, 195
ep:  47 gain: 1.000 trainlogcost: 1.0943 PairwiseErrs: 334, 98, 200
ep:  48 gain: 1.000 trainlogcost: 0.9817 PairwiseErrs: 322, 96, 178
ep:  49 gain: 1.000 trainlogcost: 1.0245 PairwiseErrs: 358, 100, 193
Energy-based errs: Train 159/50000 Valid 191/10000
Softmax-based errs: Valid 189/10000
rms:  0.0444, 0.0319, 0.0354
suprms:  0.0172, 0.0177
ep:  50 gain: 1.000 trainlogcost: 0.9921 PairwiseErrs: 358, 82, 172
ep:  51 gain: 1.000 trainlogcost: 0.8527 PairwiseErrs: 295, 61, 156
ep:  52 gain: 1.000 trainlogcost: 0.8655 PairwiseErrs: 319, 89, 167
ep:  53 gain: 1.000 trainlogcost: 1.0242 PairwiseErrs: 323, 86, 151
ep:  54 gain: 1.000 trainlogcost: 0.9642 PairwiseErrs: 306, 90, 156
Energy-based errs: Train 132/50000 Valid 175/10000
Softmax-based errs: Valid 172/10000
rms:  0.0446, 0.0312, 0.0347
suprms:  0.0170, 0.0176
ep:  55 gain: 1.000 trainlogcost: 0.9050 PairwiseErrs: 298, 77, 160
ep:  56 gain: 1.000 trainlogcost: 0.9185 PairwiseErrs: 306, 78, 180
ep:  57 gain: 1.000 trainlogcost: 0.7065 PairwiseErrs: 277, 50, 142
ep:  58 gain: 1.000 trainlogcost: 0.8138 PairwiseErrs: 287, 47, 165
ep:  59 gain: 1.000 trainlogcost: 0.7789 PairwiseErrs: 249, 49, 136
Energy-based errs: Train 173/50000 Valid 179/10000
Softmax-based errs: Valid 179/10000
rms:  0.0445, 0.0304, 0.0341
suprms:  0.0166, 0.0173
ep:  60 gain: 1.000 trainlogcost: 0.7297 PairwiseErrs: 240, 42, 123
ep:  61 gain: 1.000 trainlogcost: 0.7902 PairwiseErrs: 250, 68, 139
ep:  62 gain: 1.000 trainlogcost: 0.7796 PairwiseErrs: 252, 51, 110
ep:  63 gain: 1.000 trainlogcost: 0.7372 PairwiseErrs: 219, 51, 130
ep:  64 gain: 0.984 trainlogcost: 0.6993 PairwiseErrs: 247, 52, 111
Energy-based errs: Train 123/50000 Valid 183/10000
Softmax-based errs: Valid 182/10000
rms:  0.0444, 0.0297, 0.0334
suprms:  0.0162, 0.0171
ep:  65 gain: 0.968 trainlogcost: 0.7500 PairwiseErrs: 247, 58, 134
ep:  66 gain: 0.952 trainlogcost: 0.6869 PairwiseErrs: 211, 61, 130
ep:  67 gain: 0.936 trainlogcost: 0.6493 PairwiseErrs: 227, 44, 100
ep:  68 gain: 0.920 trainlogcost: 0.6234 PairwiseErrs: 211, 39, 105
ep:  69 gain: 0.904 trainlogcost: 0.5912 PairwiseErrs: 228, 35, 93
Energy-based errs: Train 82/50000 Valid 177/10000
Softmax-based errs: Valid 172/10000
rms:  0.0443, 0.0290, 0.0327
suprms:  0.0157, 0.0168
ep:  70 gain: 0.888 trainlogcost: 0.5819 PairwiseErrs: 215, 41, 92
ep:  71 gain: 0.872 trainlogcost: 0.6130 PairwiseErrs: 215, 34, 97
ep:  72 gain: 0.856 trainlogcost: 0.5644 PairwiseErrs: 191, 39, 94
ep:  73 gain: 0.840 trainlogcost: 0.4894 PairwiseErrs: 198, 32, 73
ep:  74 gain: 0.824 trainlogcost: 0.5247 PairwiseErrs: 194, 28, 68
Energy-based errs: Train 38/50000 Valid 161/10000
Softmax-based errs: Valid 173/10000
rms:  0.0441, 0.0282, 0.0319
suprms:  0.0152, 0.0164
ep:  75 gain: 0.808 trainlogcost: 0.4815 PairwiseErrs: 141, 16, 47
ep:  76 gain: 0.792 trainlogcost: 0.4650 PairwiseErrs: 146, 12, 51
ep:  77 gain: 0.776 trainlogcost: 0.4523 PairwiseErrs: 150, 18, 50
ep:  78 gain: 0.760 trainlogcost: 0.4682 PairwiseErrs: 154, 13, 42
ep:  79 gain: 0.744 trainlogcost: 0.4587 PairwiseErrs: 166, 14, 44
Energy-based errs: Train 40/50000 Valid 173/10000
Softmax-based errs: Valid 165/10000
rms:  0.0435, 0.0274, 0.0311
suprms:  0.0148, 0.0162
ep:  80 gain: 0.728 trainlogcost: 0.4260 PairwiseErrs: 138, 2, 35
ep:  81 gain: 0.712 trainlogcost: 0.3953 PairwiseErrs: 119, 6, 28
ep:  82 gain: 0.696 trainlogcost: 0.4259 PairwiseErrs: 117, 7, 24
ep:  83 gain: 0.680 trainlogcost: 0.4074 PairwiseErrs: 113, 2, 31
ep:  84 gain: 0.664 trainlogcost: 0.4053 PairwiseErrs: 113, 8, 37
Energy-based errs: Train 12/50000 Valid 165/10000
Softmax-based errs: Valid 163/10000
rms:  0.0427, 0.0266, 0.0303
suprms:  0.0145, 0.0159
ep:  85 gain: 0.648 trainlogcost: 0.4272 PairwiseErrs: 85, 3, 22
ep:  86 gain: 0.632 trainlogcost: 0.3929 PairwiseErrs: 99, 5, 21
ep:  87 gain: 0.616 trainlogcost: 0.4033 PairwiseErrs: 89, 5, 17
ep:  88 gain: 0.600 trainlogcost: 0.4061 PairwiseErrs: 90, 1, 12
ep:  89 gain: 0.584 trainlogcost: 0.4070 PairwiseErrs: 96, 2, 16
Energy-based errs: Train 9/50000 Valid 151/10000
Softmax-based errs: Valid 168/10000
rms:  0.0418, 0.0259, 0.0295
suprms:  0.0146, 0.0161
ep:  90 gain: 0.568 trainlogcost: 0.4083 PairwiseErrs: 77, 2, 16
ep:  91 gain: 0.552 trainlogcost: 0.3981 PairwiseErrs: 66, 3, 14
ep:  92 gain: 0.536 trainlogcost: 0.4081 PairwiseErrs: 81, 2, 11
ep:  93 gain: 0.520 trainlogcost: 0.3977 PairwiseErrs: 68, 1, 7
ep:  94 gain: 0.504 trainlogcost: 0.3725 PairwiseErrs: 77, 2, 7
Energy-based errs: Train 8/50000 Valid 156/10000
Softmax-based errs: Valid 161/10000
rms:  0.0411, 0.0253, 0.0289
suprms:  0.0146, 0.0163
ep:  95 gain: 0.488 trainlogcost: 0.3916 PairwiseErrs: 71, 1, 5
ep:  96 gain: 0.472 trainlogcost: 0.3894 PairwiseErrs: 61, 1, 6
ep:  97 gain: 0.456 trainlogcost: 0.3923 PairwiseErrs: 75, 0, 4
ep:  98 gain: 0.440 trainlogcost: 0.4035 PairwiseErrs: 61, 0, 5
ep:  99 gain: 0.424 trainlogcost: 0.3927 PairwiseErrs: 58, 0, 2
Energy-based errs: Train 4/50000 Valid 146/10000
Softmax-based errs: Valid 162/10000
rms:  0.0404, 0.0247, 0.0283
suprms:  0.0147, 0.0165
ep: 100 gain: 0.408 trainlogcost: 0.3968 PairwiseErrs: 60, 0, 2
ep: 101 gain: 0.392 trainlogcost: 0.3965 PairwiseErrs: 36, 0, 2
ep: 102 gain: 0.376 trainlogcost: 0.3948 PairwiseErrs: 57, 0, 4
ep: 103 gain: 0.360 trainlogcost: 0.3887 PairwiseErrs: 50, 0, 2
ep: 104 gain: 0.344 trainlogcost: 0.3997 PairwiseErrs: 47, 0, 4
Energy-based errs: Train 2/50000 Valid 148/10000
Softmax-based errs: Valid 170/10000
rms:  0.0397, 0.0243, 0.0278
suprms:  0.0147, 0.0166
ep: 105 gain: 0.328 trainlogcost: 0.4037 PairwiseErrs: 36, 0, 0
ep: 106 gain: 0.312 trainlogcost: 0.4162 PairwiseErrs: 46, 0, 0
ep: 107 gain: 0.296 trainlogcost: 0.4211 PairwiseErrs: 39, 0, 2
ep: 108 gain: 0.280 trainlogcost: 0.4224 PairwiseErrs: 40, 0, 0
ep: 109 gain: 0.264 trainlogcost: 0.4185 PairwiseErrs: 35, 0, 0
Energy-based errs: Train 1/50000 Valid 138/10000
Softmax-based errs: Valid 161/10000
rms:  0.0392, 0.0240, 0.0274
suprms:  0.0149, 0.0169
ep: 110 gain: 0.248 trainlogcost: 0.4191 PairwiseErrs: 34, 0, 4
ep: 111 gain: 0.232 trainlogcost: 0.4227 PairwiseErrs: 35, 0, 0
ep: 112 gain: 0.216 trainlogcost: 0.4165 PairwiseErrs: 32, 0, 1
ep: 113 gain: 0.200 trainlogcost: 0.4290 PairwiseErrs: 29, 0, 0
ep: 114 gain: 0.184 trainlogcost: 0.4280 PairwiseErrs: 34, 0, 0
Energy-based errs: Train 1/50000 Valid 145/10000
Softmax-based errs: Valid 163/10000
rms:  0.0389, 0.0237, 0.0271
suprms:  0.0151, 0.0170
ep: 115 gain: 0.168 trainlogcost: 0.4329 PairwiseErrs: 34, 0, 0
ep: 116 gain: 0.152 trainlogcost: 0.4330 PairwiseErrs: 28, 0, 0
ep: 117 gain: 0.136 trainlogcost: 0.4374 PairwiseErrs: 23, 0, 0
ep: 118 gain: 0.120 trainlogcost: 0.4405 PairwiseErrs: 23, 0, 1
ep: 119 gain: 0.104 trainlogcost: 0.4478 PairwiseErrs: 24, 0, 0
Energy-based errs: Train 1/50000 Valid 150/10000
Softmax-based errs: Valid 163/10000
rms:  0.0386, 0.0236, 0.0270
suprms:  0.0152, 0.0172
ep: 120 gain: 0.088 trainlogcost: 0.4468 PairwiseErrs: 13, 0, 0
ep: 121 gain: 0.072 trainlogcost: 0.4476 PairwiseErrs: 21, 0, 0
ep: 122 gain: 0.056 trainlogcost: 0.4467 PairwiseErrs: 14, 0, 0
ep: 123 gain: 0.040 trainlogcost: 0.4450 PairwiseErrs: 23, 0, 0
ep: 124 gain: 0.024 trainlogcost: 0.4454 PairwiseErrs: 19, 0, 0
Energy-based errs: Train 1/50000 Valid 150/10000
Softmax-based errs: Valid 166/10000
rms:  0.0385, 0.0235, 0.0269
suprms:  0.0152, 0.0172
Energy-based errs: Train 1/50000 Test 133/10000
Softmax-based errs: Train 1/50000 Test 138/10000

real	42m1.433s
user	30m12.164s
sys	5m10.400s

% sysctl -a | grep machdep.cpu
machdep.cpu.core_count: 8
machdep.cpu.brand_string: Apple M1

About

Forward-Forward algorithm on MNIST

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages