-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmpn.py
More file actions
813 lines (643 loc) · 34.8 KB
/
mpn.py
File metadata and controls
813 lines (643 loc) · 34.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
import torch
from torch import nn
from torch.utils.data import TensorDataset
import torch.nn.functional as F
from torch.nn.init import orthogonal_
import math
from net_helpers import BaseNetwork, BaseNetworkFunctions
from net_helpers import rand_weight_init, get_activation_function
import numpy as np
import copy
import time
class MultiPlasticLayer(BaseNetworkFunctions):
    """
    Fully-connected layer with multi-plasticity.

    This functions very similarly to a fully connected PyTorch layer. However,
    in addition to the implementation of the layer in a regular forward pass,
    the modulations need to be correctly kept track of through the "reset_state"
    and "update_M_matrix" functions. The latter needs to be called after every
    forward pass where the modulations should be updated.
    """

    def __init__(self, ml_params, output_matrix, verbose=True):
        """
        INPUTS:
            ml_params: dict of layer hyperparameters; see the .get() calls below
                for recognized keys and their defaults.
            output_matrix: unused by the layer itself; accepted only so network
                classes can pass it uniformly when constructing layers.
            verbose: if True, prints a summary of the layer's parameters.
        """
        super().__init__()
        init_string = ''
        self.verbose = verbose
        # Name appended to various parameters to distinguish between layers
        self.mp_layer_name = ml_params.get('mpl_name', '')
        self.n_input = ml_params['n_input']
        self.n_output = ml_params['n_output']
        init_string += ' MP Layer{} parameters:\n'.format(self.mp_layer_name)
        init_string += ' n_neurons - input: {}, output: {}'.format(
            self.n_input, self.n_output
        )
        # Determines whether or not layer weights are trainable parameters
        self.layer_bias = ml_params.get('bias', True)
        self.freeze_layer = ml_params.get('freeze_layer', False)
        # Local params list that can be merged with the full network params if needed.
        # Bug fix: this was previously only defined in the non-frozen branch, so a
        # frozen layer crashed later at self.params.append('eta') / 'W' in self.params.
        self.params = []
        if self.freeze_layer:
            init_string += ' W: Frozen // '
            if self.layer_bias:
                init_string += 'b: Frozen //'
        else:
            self.params.append('W')
            if self.layer_bias:
                self.params.append('b')
        ### Weight/bias initialization ###
        # Input weights
        self.W_init = ml_params.get('W_init', 'xavier')
        W_freeze = ml_params.get('W_freeze', False)
        # Initialize the weight tensor once.
        W_tensor = torch.tensor(
            rand_weight_init(self.n_input, self.n_output, init_type=self.W_init, cell_types=None),
            dtype=torch.float
        )
        if W_freeze:
            print("MPN Layer W Frozen")
            self.register_buffer('W', W_tensor)
        else:
            self.parameter_or_buffer('W', W_tensor)
        # Bias term (zero-initialized when inactive so forward() can always add it)
        if self.layer_bias:
            self.b_init = 'gaussian'
        else:
            self.b_init = 'zeros'
        self.parameter_or_buffer('b', torch.tensor(
            rand_weight_init(self.n_output, init_type=self.b_init),
            dtype=torch.float)
        )
        ###### M matrix-related specs ########
        init_string += '\n M matrix parameters:'
        self.mp_type = ml_params.get('mp_type', 'mult')
        # Controls the update equation of the M matrix (calculation of \Delta M)
        self.m_update_type = ml_params.get('m_update_type', 'hebb_assoc')
        # Activation function to pass M through after update (can enforce bounds)
        self.m_act = ml_params.get('m_activation', 'linear')
        self.m_act_fn, self.m_act_fn_np, self.m_act_fn_p = get_activation_function(self.m_act)
        # Initial modulation values; note shape is (n_output, n_input), matching W
        self.register_buffer('M_init', torch.zeros((self.n_output, self.n_input,), dtype=torch.float))
        # Controls maximum and minimum values of modulations so weights don't change signs
        self.modulation_bounds = ml_params.get('modulation_bounds', True)
        if self.modulation_bounds:
            self.M_bound_vals = ml_params.get('m_bounds', (-1.0, 1.0,))  # (min, max)
            M_bounds, init_string = self.build_M_bounds(init_string=init_string)
            self.register_buffer('M_bounds', M_bounds)
            # These bounds will need to be continually updated if W is variable and
            # M is additive, which is not yet implemented
            if self.mp_type == 'add' and 'W' in self.params:
                raise NotImplementedError('Need to continuously update bounds in this case.')
        init_string += ' type: {} // Update - type: {} // Act fn: {}'.format(
            self.mp_type, self.m_update_type, self.m_act
        )
        ### Eta dependencies ###
        self.eta_train = ml_params.get('eta_train', True)
        eta_train_str = 'fixed'
        if self.eta_train:
            self.params.append('eta')
            eta_train_str = 'train'
        self.eta_type = ml_params.get('eta_type', 'scalar')
        self.eta_init = ml_params.get('eta_init', 'eta_clamp')
        self.eta_clamp = ml_params.get('eta_clamp', 1.00)
        self.parameter_or_buffer('eta', torch.tensor(
            self.init_M_parameter(param_type=self.eta_type, init_type=self.eta_init),
            dtype=torch.float))
        if self.eta_type in ('scalar', 'pre_vector', 'post_vector', 'matrix'):
            init_string += '\n Eta: {} ({}) // '.format(self.eta_type, eta_train_str)
        else:
            raise ValueError('eta_type: {} not recognized'.format(self.eta_type))
        ### Lambda dependencies ###
        self.lam_train = ml_params.get('lam_train', True)
        lam_train_str = 'fixed'
        if self.lam_train:
            self.params.append('lam')
            lam_train_str = 'train'
        self.lam_type = ml_params.get('lam_type', 'scalar')
        self.lam_init = ml_params.get('lam_init', 'lam_clamp')
        # Maximum lambda value/corresponding decay time constant (both always computed)
        if 'm_time_scale' in ml_params:
            self.m_time_scale = ml_params.get('m_time_scale')
            self.lam_clamp = 1. - ml_params['dt'] / self.m_time_scale
        else:
            self.lam_clamp = ml_params.get('lam_clamp', 0.95)
            self.m_time_scale = ml_params['dt'] / (1. - self.lam_clamp)
        self.parameter_or_buffer('lam', torch.tensor(
            np.abs(self.init_M_parameter(param_type=self.lam_type, init_type=self.lam_init)),  # Always positive
            dtype=torch.float))
        if self.lam_type in ('scalar', 'pre_vector', 'post_vector', 'matrix'):
            init_string += 'Lambda: {} ({}) // Lambda_max: {:.2f} (tau: {:.1e})'.format(
                self.lam_type, lam_train_str, self.lam_clamp, self.m_time_scale
            )
        else:
            raise ValueError('lam_type: {} not recognized'.format(self.lam_type))
        if self.verbose:  # Full summary of mp_layer parameters
            print(init_string)

    def reset_state(self, B=1):
        """
        Resets/initializes modulation values for a batch of size B.
        """
        # Note: M has shape (B, n_output, n_input), matching W (the previous
        # comment claimed (B, n_input, n_output), which contradicted M_init and
        # the 'BiI, BI -> Bi' einsum in forward()).
        self.M = torch.ones(B, *self.W.shape, device=self.W.device)
        self.M = self.M * self.M_init.unsqueeze(0)  # broadcast (n_output, n_input) over batch
        self.M_pre = torch.zeros_like(self.M)

    @torch.no_grad()
    def param_clamp(self):
        """ Enforce lambda bounds. Doesn't track gradients, since this is always called after weight updates. """
        self.lam.data.clamp_(0., self.lam_clamp)

    def init_M_parameter(self, param_type='scalar', init_type='gaussian'):
        """
        Initialize different forms of the various M parameters (e.g. eta and lambda).
        Default is just one of each parameter for each layer, but can make them
        post- and/or presynaptic cell dependent.
        Turned into a buffer/parameter externally.
        """
        if isinstance(init_type, float):
            if param_type == 'scalar':  # Just directly set to init_type
                param = init_type
            elif param_type == 'pre_vector':  # Serves as mean to distribution
                # Typo fix: init_type was misspelled 'guassian' here and below,
                # inconsistent with 'gaussian' used elsewhere in this file.
                param = init_type + rand_weight_init(self.n_input, init_type='gaussian', weight_norm=1.0)
            elif param_type == 'post_vector':
                param = init_type + rand_weight_init(self.n_output, init_type='gaussian', weight_norm=1.0)
            elif param_type == 'matrix':
                param = init_type + rand_weight_init(self.n_input, self.n_output, init_type='gaussian', weight_norm=1.0)
            else:  # Previously fell through and raised a confusing NameError
                raise ValueError('param_type: {} not recognized'.format(param_type))
        elif isinstance(init_type, str):
            if param_type == 'scalar':
                if init_type in ('lam_clamp',):
                    param = self.lam_clamp
                elif init_type in ('eta_clamp',):
                    param = self.eta_clamp
                else:
                    param = rand_weight_init(1, init_type=init_type, weight_norm=1.0)
            elif param_type == 'pre_vector':
                if init_type in ('lam_clamp',):
                    param = self.lam_clamp * np.ones((self.n_input,))
                elif init_type in ('eta_clamp',):
                    param = self.eta_clamp * np.ones((self.n_input,))
                else:
                    param = rand_weight_init(self.n_input, init_type=init_type, weight_norm=1.0)
            elif param_type == 'post_vector':
                if init_type in ('lam_clamp',):
                    param = self.lam_clamp * np.ones((self.n_output,))
                elif init_type in ('eta_clamp',):
                    param = self.eta_clamp * np.ones((self.n_output,))
                else:
                    param = rand_weight_init(self.n_output, init_type=init_type, weight_norm=1.0)
            elif param_type == 'matrix':
                if init_type in ('lam_clamp',):
                    param = self.lam_clamp * np.random.rand(self.n_output, self.n_input,)
                elif init_type in ('eta_clamp',):
                    param = self.eta_clamp * np.random.rand(self.n_output, self.n_input,)
                else:
                    param = rand_weight_init(self.n_input, self.n_output, init_type=init_type, weight_norm=1.0)
            else:
                raise ValueError('param_type: {} not recognized'.format(param_type))
        else:
            raise ValueError('Init of type {} not recognized: {}'.format(type(init_type), init_type))
        return param

    def build_M_parameter(self, param, param_type='scalar'):
        """
        Returns M parameters (e.g. eta and lambda) of the appropriate dimensions
        given their type
        OUTPUTS:
            param_matrix shape is simply something that can be cast to the shape of
            M without batch dim: (n_output, n_input), so len(param_expanded.shape)
            == 2 always
        """
        if param_type == 'scalar':
            param_expanded = param.unsqueeze(0)  # shape (1, 1)
        elif param_type == 'pre_vector':
            param_expanded = param.unsqueeze(0)  # shape (1, n_input)
        elif param_type == 'post_vector':
            param_expanded = param.unsqueeze(-1)  # shape (n_output, 1)
        elif param_type == 'matrix':
            param_expanded = param  # shape (n_output, n_input)
        else:  # Previously fell through and raised a confusing NameError
            raise ValueError('param_type: {} not recognized'.format(param_type))
        return param_expanded

    def build_M_bounds(self, init_string=''):
        """
        Controls maximum and minimum values of modulations. Generally used so
        weights don't change signs (since these are often tied to cell type).
        Bounds are in order: (upper_vals, lower_vals)
        """
        W_fixed = self.W.detach()
        if self.mp_type == 'add':
            MAX_ADD = self.M_bound_vals[1]  # Controls how much a weight can be strengthened, 1.0 means the weight's mag can be doubled, 0.0 means it can't be strengthened
            MIN_ADD = self.M_bound_vals[0]  # Any value >0 prevents weight from being fully weakened, e.g. 0.2 means the weight can be weakened to at most 20% of its original value
            # Explanation of this expression: (with example values MAX_ADD = 2.0 and MIN_ADD = 0.2)
            # First line: Upper bounds on Ms
            #   For W_ij > 0: Maximum M value is MAX_ADD * W_ij, so 2 * W_ij > 0, meaning positive weights can be strengthened to 3x their initial value
            #   For W_ij < 0: Maximum M value is -1 * (1 - MIN_ADD) * W_ij = -0.8 * W_ij > 0, since W_ij is negative. Since a positive M_ij would cancel the
            #                 negative W_ij, this means that at most W_ij can be weakened to 0.2 x its original value
            # Second line: Lower bounds for M
            #   For W_ij > 0: Minimum M value is -1 * (1 - MIN_ADD) * W_ij = -0.8 * W_ij < 0, since W_ij is positive, a negative M_ij that saturates this bound
            #                 would reduce W_ij to 0.2 x its original value.
            #   For W_ij < 0: Minimum M value is MAX_ADD * W_ij = 2 * W_ij < 0, since W_ij is negative. So can strengthen negative weight to 3x its initial value
            M_bounds = torch.cat((
                (MAX_ADD * W_fixed * (W_fixed > 0) - 1 * (1 - MIN_ADD) * W_fixed * (W_fixed < 0)).unsqueeze(0),
                (MAX_ADD * W_fixed * (W_fixed < 0) - 1 * (1 - MIN_ADD) * W_fixed * (W_fixed > 0)).unsqueeze(0)
            ))
            init_string += ' update bounds - Max add: {}, Min add: {}\n'.format(MAX_ADD, MIN_ADD)
        elif self.mp_type == 'mult':
            max_mult = self.M_bound_vals[1]  # Controls how much a weight can be enhanced, 1.0 means the weight's mag can be doubled
            min_mult = self.M_bound_vals[0]  # Controls how much a weight can be depressed, -1.0 means it can be fully depressed
            M_bounds = torch.cat((
                max_mult * torch.ones_like(W_fixed).unsqueeze(0),
                min_mult * torch.ones_like(W_fixed).unsqueeze(0)
            ))
            init_string += ' update bounds - Max mult: {}, Min mult: {}\n'.format(max_mult, min_mult)
        else:
            raise ValueError('MP type not recognized in build_M_bounds.')
        return M_bounds, init_string

    def update_M_matrix(self, pre, post, update_mask=None):
        """
        Updates the modulation matrix from one time step to the next.
        Should only be called in the network_step pass once. Directly updates self.M.
        Note that this is NOT automatically called in MP layer's "forward" call, since the
        postsynaptic activity could be dependent on other factors (e.g. other layers).
        M updates can be frozen from the update_mask, if a given batch idx is
        False (e.g. because it is beyond the end of the sequence)
        pre.shape: (B, n_input)
        post.shape: (B, n_output)
        update_mask: (B,)
        """
        eta = self.build_M_parameter(self.eta, self.eta_type)
        lam = self.build_M_parameter(self.lam, self.lam_type)
        M = self.M
        delta_M = torch.zeros_like(M)
        if update_mask is not None:  # Update each batch_idx individually using the update_mask
            for batch_idx in range(M.shape[0]):
                if update_mask[batch_idx]:  # Only calculates delta_M if batch is being updated
                    if self.m_update_type in ('hebb_pre',):
                        # Presynaptic-only rule: replace post with a uniform vector
                        post = 1 / math.sqrt(post.shape[-1]) * torch.ones_like(post)
                    if self.m_update_type in ('hebb_assoc', 'hebb_pre',):
                        delta_M[batch_idx] = - M[batch_idx] + lam * M[batch_idx] + eta * torch.einsum(
                            'i, I -> iI', post[batch_idx], pre[batch_idx]
                        )
                    elif self.m_update_type in ('oja',):
                        delta_M[batch_idx] = (eta * torch.einsum('i, I -> iI', post[batch_idx], pre[batch_idx]) -
                                              torch.abs(eta) * torch.einsum('i, iI -> iI', post[batch_idx]**2, M[batch_idx]))
        else:  # Update all M at once
            if self.m_update_type in ('hebb_pre',):
                post = 1 / math.sqrt(post.shape[-1]) * torch.ones_like(post)
            if self.m_update_type in ('hebb_assoc', 'hebb_pre',):
                delta_M = - M + lam.unsqueeze(0) * M + eta.unsqueeze(0) * torch.einsum(
                    'Bi, BI -> BiI', post, pre
                )
            elif self.m_update_type in ('oja',):
                raise NotImplementedError('Need to update to a batched version.')
                # delta_M[batch_idx] = (eta * torch.einsum('i, I -> iI', post[batch_idx], pre[batch_idx]) -
                #     torch.abs(eta) * torch.einsum('i, iI -> iI', post[batch_idx]**2, M[batch_idx]))
        # Keep the pre-activation M around for tracking, then squash through the act fn
        self.M_pre = self.M + delta_M
        self.M = self.m_act_fn(self.M_pre)
        # Update M matrices, while being sure update holds matrix within bounds
        # (note: updates to restricted cell types is built into the eta matrix)
        # M_bounds is ordered (upper_vals, lower_vals), hence min=bounds[1], max=bounds[0]
        if self.modulation_bounds:
            self.M = torch.clamp(self.M, min=self.M_bounds[1], max=self.M_bounds[0])
        return delta_M  # This is only used for theory matching

    def get_modulated_weights(self, M=None):
        """
        Returns modulated weights, taking into account exactly how M and W are combined.
        Note the batch size of the modulated weights is set by the batch size of M.
        If M is not None, the passed M matrix is used, otherwise just uses stored Ms (this
        is used for analysis of the network after training)
        """
        W = self.W
        if M is None:
            M = self.M
        # Fixed weights and M matrix, either multiplicative or additive
        if self.mp_type == 'mult':
            modulated_weights = W.unsqueeze(0) + W.unsqueeze(0) * M
        elif self.mp_type == 'add':
            modulated_weights = W.unsqueeze(0) + M
        return modulated_weights

    def forward(self, x, run_mode='minimal'):
        """
        Passes inputs through the modulated weights. Activation are handled
        externally to allow for more complicated architectures.
        Updating of modulations is handled externally
        INPUTS:
            x.shape: [B, n_input]
        INTERNALS:
            b.shape: [n_output]
            modulated_weights.shape: [B, n_output, n_input],
        OUTPUTS:
            pre_act.shape: [B, n_output]
        """
        modulated_weights = self.get_modulated_weights()
        pre_act_no_bias = torch.einsum('BiI, BI -> Bi', modulated_weights, x)
        pre_act = pre_act_no_bias + self.b.unsqueeze(0)
        if run_mode in ('track_states',):
            db = {
                'pre_act_no_bias': pre_act_no_bias.detach(),
                'M': self.M.detach(),
                'b': self.b.detach().unsqueeze(0),  # record bias information as well
            }
            # NOTE(review): self.m_act is a non-empty string (e.g. 'linear'), so
            # this condition is always truthy as written — confirm whether the
            # intent was to record M_pre only for non-identity activations.
            if self.m_act:
                db['M_pre'] = self.M_pre.detach()
        else:
            db = None
        return pre_act, db
class MultiPlasticNetBase(BaseNetwork):
    """
    Base network for the multiplastic network. Initializes things like the output
    layers and activation functions that are mostly shared across all types of
    MPNs, no matter the connections that lead from input to output.
    """

    def __init__(self, net_params, n_output_pre, output_matrix="", verbose=False):
        """
        INPUTS:
            net_params: dict of network hyperparameters.
            n_output_pre: size of the layer feeding the readout (W_output is
                (n_output_pre, n_output) as produced by rand_weight_init).
            output_matrix: '' (trained readout), 'untrained', or 'orthogonal'.
            verbose: if True, prints a summary of readout parameters.
        """
        # Note that this assumes self.n_output has already been set in the child class
        assert hasattr(self, 'n_output')
        super().__init__(net_params, verbose=verbose)
        init_string = 'MultiPlastic Net:\n'
        init_string += ' output neurons: {}\n'.format(
            self.n_output
        )
        # Get list that should be parameters of the network (for later matching to experiment)
        self.params = ['W_output',]  # Biases can be added below if used.
        # Numpy equivalents only used for debugging purposes
        self.act = net_params.get('activation', 'linear')
        self.act_fn, self.act_fn_np, self.act_fn_p = get_activation_function(self.act)
        init_string += ' Act: {}\n'.format(
            self.act,
        )
        self.b_output_active = net_params.get('output_bias', True)
        if self.b_output_active:
            self.params.append('b_output')
            self.b_output_init = 'gaussian'
        else:
            self.b_output_init = 'zeros'
        # By default, these are set to none but are overridden in the initialization if cell types are used
        self.ei_balance = None
        self.input_cell_types = None
        self.hidden_cell_types = None
        self.cell_types = net_params.get('cell_types', None)
        if self.cell_types:
            raise NotImplementedError()
        # Output weights
        self.W_output_init = net_params.get('W_output_init', 'xavier')
        self.parameter_or_buffer('W_output', torch.tensor(
            rand_weight_init(n_output_pre, self.n_output, init_type=self.W_output_init,
                             cell_types=self.hidden_cell_types),
            dtype=torch.float)
        )
        # Optionally overwrite the readout just initialized above
        if output_matrix == "":
            pass
        elif output_matrix == "untrained":
            print("Output Matrix Untrained")
            self.W_output.requires_grad = False
        elif output_matrix == "orthogonal":
            print("Output Matrix Orthogonal and Untrained")
            W_output_init = torch.empty(self.n_output, n_output_pre)  # (n_output, n_output_pre)
            orthogonal_(W_output_init)  # In-place orthogonal initialization
            W_output_init = W_output_init.T
            # Fix: torch.tensor() on an existing tensor emits a UserWarning and is
            # the discouraged copy path; detach().clone() is the recommended form.
            self.parameter_or_buffer('W_output', W_output_init.detach().clone().to(torch.float))
            self.W_output.requires_grad = False
        else:
            raise Exception("Output Matrix not recognized")
        self.parameter_or_buffer('b_output', torch.tensor(
            rand_weight_init(self.n_output, init_type=self.b_output_init),
            dtype=torch.float)
        )
        if verbose:  # Full summary of readout parameters (MP layer prints out internally)
            print(init_string)

    def reset_state(self, B=1):
        """ Resets states of all internal layer M matrices """
        for mp_layer in self.mp_layers:
            mp_layer.reset_state(B=B)

    def param_clamp(self):
        """ Clamp trainable M parameters (lambda) of every MP layer. """
        # mp_layer call doesn't track gradients, since this is always called after weight updates
        for mp_layer in self.mp_layers:
            mp_layer.param_clamp()

    @torch.no_grad()
    def _monitor_init(self, train_params, train_data, train_trails=None, valid_batch=None, valid_trails=None):
        """ Register per-layer eta/lam histories before delegating to the base monitor. """
        # Additional quantities to track during training, note initializes these first so that _monitor call
        # inside super()._monitor_init can append additional quantities.
        # NOTE(review): assumes self.hist is set (possibly to None) by BaseNetwork — confirm.
        if self.hist is None:
            self.hist = {
                'iter': 0,
            }
        for mpl_idx, mp_layer in enumerate(self.mp_layers):
            self.hist['eta{}'.format(mp_layer.mp_layer_name)] = []
            self.hist['lam{}'.format(mp_layer.mp_layer_name)] = []
        super()._monitor_init(train_params, train_data, train_trails=train_trails, valid_batch=valid_batch, valid_trails=valid_trails)

    @torch.no_grad()
    def _monitor(self, train_batch, train_go_info_batch, valid_go_info_batch, train_type='supervised', output=None, loss=None, loss_components=None,
                 acc=None, valid_batch=None, nowiter=None):
        """ Append current eta/lam values for every MP layer after the base monitoring pass. """
        super()._monitor(train_batch, train_go_info_batch, valid_go_info_batch, output=output, loss=loss, loss_components=loss_components,
                         valid_batch=valid_batch, nowiter=nowiter)
        for mpl_idx, mp_layer in enumerate(self.mp_layers):
            self.hist['eta{}'.format(mp_layer.mp_layer_name)].append(
                mp_layer.eta.detach().cpu().numpy()
            )
            self.hist['lam{}'.format(mp_layer.mp_layer_name)].append(
                mp_layer.lam.detach().cpu().numpy()
            )
class MultiPlasticNet(MultiPlasticNetBase):
    """
    Two-layer feedforward setup, with single multi-plastic layer followed by a readout layer.
    """

    def __init__(self, net_params, verbose=False):
        # Layer sizes can be given either as a single 'n_neurons' list or individually
        if 'n_neurons' in net_params:
            self.n_input = net_params['n_neurons'][0]
            self.n_hidden = net_params['n_neurons'][1]
            self.n_output = net_params['n_neurons'][2]
        else:
            self.n_input = net_params['n_input']
            self.n_hidden = net_params['n_hidden']
            self.n_output = net_params['n_output']
        super().__init__(net_params, self.n_hidden, verbose=verbose)
        # Creates the input MP layer
        self.param_clamping = True  # Always have param clamping for MP layers because lam bounds
        net_params['ml_params']['n_input'] = self.n_input
        net_params['ml_params']['n_output'] = self.n_hidden
        net_params['ml_params']['dt'] = self.dt
        # NOTE(review): self.output_matrix is not set in this class (DeepMultiPlasticNet
        # sets it from cfg['output_matrix']) — presumably BaseNetwork sets it; verify.
        self.mp_layer = MultiPlasticLayer(net_params['ml_params'], output_matrix=self.output_matrix, verbose=verbose)
        self.params.extend(self.mp_layer.params)
        self.mp_layers = [self.mp_layer,]  # List of all mp_layers in this network

    def forward(self, inputs, run_mode='minimal', verbose=False):
        """
        Single forward pass: MP layer -> activation -> linear readout.
        Returns (output, hidden, db) where db is populated only in 'track_states' mode.
        """
        x = torch.clone(inputs)
        # Returns pre-activation
        hidden_pre, db_mp = self.mp_layer(x, run_mode=run_mode)
        hidden = self.act_fn(hidden_pre)
        output_hidden = torch.einsum('iI, BI -> Bi', self.W_output, hidden)
        output = output_hidden + self.b_output.unsqueeze(0)
        # Bug fix: the original tuple was missing a trailing comma — ('track_states')
        # is just a string, so this was a substring test instead of membership.
        if run_mode in ('track_states',):
            db = {
                'M': db_mp['M'],
                'hidden_pre': hidden_pre.detach(),
                'hidden': hidden.detach(),
                "input": x.detach(),
            }
        else:
            db = None
        return output, hidden, db

    def network_step(self, current_input, run_mode='minimal', verbose=False):
        """
        Performs a single batch pass forward for the network. This mostly consists of a forward pass and
        the associated updates to internal states (i.e. the modulations)
        This should not be passed a full sequence of data, only data from a given time point
        """
        assert len(current_input.shape) == 2
        output, current_hidden, db = self.forward(current_input, run_mode=run_mode, verbose=verbose)
        # M updated internally when this is called; the returned delta is only used
        # when finding fixed points (not yet implemented), so discard it here.
        _ = self.mp_layer.update_M_matrix(current_input, current_hidden)
        return output, db
class DeepMultiPlasticNet(MultiPlasticNetBase):
    """
    N-layer feedforward setup, with N-1 multi-plastic layers followed by a single readout layer.
    Optionally prepends a (possibly frozen) linear embedding layer and/or adds a
    simple recurrent term on the MP-layer hidden activity.
    """

    def __init__(self, net_params, verbose=False, forzihan=True):
        """
        INPUTS:
            net_params: dict of network hyperparameters (deep-copied; per-layer MP
                settings under 'ml_params' or 'ml_params{idx}').
            verbose: if True, prints parameter summaries.
            forzihan: if True, asserts the configuration has exactly one MP layer.
        """
        cfg = copy.deepcopy(net_params)
        # Optional initial linear embedding layer in front of the MP layers
        self.input_layer_active = cfg.get('input_layer_add', False)
        self.input_layer_active_trainable = cfg.get('input_layer_add_trainable', False)
        # Optional recurrent augmentation of the hidden activity
        self.recurrent_active = cfg.get('recurrent_layer_add', False)
        arch = list(cfg['n_neurons'])
        if self.input_layer_active:
            arch.insert(1, cfg.get('linear_embed', 128))  # add an initial linear embedding layer
        n_layers = len(arch) - 1
        self.n_input = cfg['n_neurons'][0]
        self.n_hidden = cfg['n_neurons'][1]
        self.n_output = cfg['n_neurons'][-1]
        self.output_matrix = cfg['output_matrix']
        super().__init__(cfg, cfg['n_neurons'][-2], output_matrix=self.output_matrix, verbose=verbose)
        # Creates all the MP layers
        self.param_clamping = True  # Always have param clamping for MP layers because lam bounds
        input_init_type = net_params.get("input_init_type", "xavier")
        if self.input_layer_active:
            self.W_initial_linear = nn.Linear(arch[0], arch[1])
            with torch.no_grad():
                if input_init_type == "identity":
                    # Identity on the overlapping square block, zeros elsewhere
                    W = self.W_initial_linear.weight
                    W.zero_()
                    k = min(W.size(0), W.size(1))
                    W[:k, :k].copy_(torch.eye(k, device=W.device, dtype=W.dtype))
                elif input_init_type == "identity_noise":
                    # Identity block plus small Gaussian perturbation
                    eps = net_params.get("identity_eps", 1e-3)
                    W = self.W_initial_linear.weight
                    W.zero_()
                    k = min(W.size(0), W.size(1))
                    W[:k, :k].copy_(torch.eye(k, device=W.device, dtype=W.dtype))
                    W.add_(eps * torch.randn_like(W))
                elif input_init_type == "orthogonal":
                    torch.nn.init.orthogonal_(self.W_initial_linear.weight, gain=1.0)
                else:
                    # Fall back to the project's generic random initializers
                    self.W_initial_linear.weight.copy_(torch.tensor(
                        rand_weight_init(arch[0], arch[1], init_type=input_init_type),
                        device=self.W_initial_linear.weight.device,
                        dtype=self.W_initial_linear.weight.dtype
                    ))
            if not self.input_layer_active_trainable:
                print(' Input Layer Frozen.')
                self.W_initial_linear.weight.requires_grad = False
            if net_params.get('input_layer_bias', False):
                self.W_initial_linear.bias.data = torch.tensor(
                    rand_weight_init(arch[1], init_type='gaussian'),
                    dtype=torch.float
                )
            else:
                self.W_initial_linear.bias = None
        self.mp_layers = []
        self.h_recurrent = []  # previous-step hidden activity (used when recurrent_active)
        self.recurrent_weights = nn.ParameterList()
        # (Removed a stray debug print of self.dt that was left here.)
        start_layer_count = 1 if self.input_layer_active else 0
        # if additional input layer is added, shift the starting index of layer counting from 1
        for mpl_idx in range(start_layer_count, n_layers - 1):  # (e.g. three-layer has two MP layers)
            if forzihan:
                assert n_layers - 1 - start_layer_count == 1, "2025-10-29: One-Layer MPN Now"
            # Per-layer MP settings may be given as 'ml_params{idx}', else shared 'ml_params'
            ml_key = f'ml_params{mpl_idx}' if f'ml_params{mpl_idx}' in cfg else 'ml_params'
            # Updates some parameters for each new MPL
            cfg[ml_key]['dt'] = self.dt
            cfg[ml_key]['mpl_name'] = str(mpl_idx)
            cfg[ml_key]['n_input'] = arch[mpl_idx]
            cfg[ml_key]['n_output'] = arch[mpl_idx + 1]
            setattr(
                self,
                f'mp_layer{mpl_idx}',
                MultiPlasticLayer(cfg[ml_key], output_matrix=self.output_matrix, verbose=verbose)
            )
            self.mp_layers.append(getattr(self, 'mp_layer{}'.format(mpl_idx)))
            self.params.extend([param + str(mpl_idx) for param in self.mp_layers[-1].params])
            # Optional hidden recurrency for this layer
            if self.recurrent_active:
                n_units = arch[mpl_idx + 1]  # layer output size
                # simple scaled orthogonal / xavier-ish init
                W_rec_init = rand_weight_init(n_units, n_units, init_type=net_params.get('W_rec_init', 'xavier'))
                W_rec_param = nn.Parameter(torch.tensor(W_rec_init, dtype=torch.float), requires_grad=True)
                self.recurrent_weights.append(W_rec_param)
        if not self.recurrent_active:
            print(' No Hidden Recurrency.')
        else:
            print(' Hidden Recurrency is Added.')

    def forward(self, inputs, run_mode='minimal', verbose=False):
        """
        Forward pass through (optional) input embedding, all MP layers, and readout.
        The x_t in the MPN paper is the "input to the MPN layer", i.e. after the
        optional embedding MLP.
        Returns (output, mpl_activities, db); mpl_activities holds pre/post-synaptic
        activities used later to update the M matrices.
        """
        if self.input_layer_active:
            x = self.W_initial_linear(inputs)
            x = self.act_fn(x)
        else:
            x = inputs
        layer_input = torch.clone(x)
        mpl_activities = [x,]  # Used for updating the M matrices
        db = {} if run_mode in ('track_states',) else None
        for mpl_idx, mp_layer in enumerate(self.mp_layers):
            # Returns pre-activation
            layer_input_old = layer_input.clone()
            hidden_pre, db_mp = mp_layer(layer_input, run_mode=run_mode)
            if self.recurrent_active:
                # NOTE(review): a single slot of h_recurrent is shared across layers;
                # this is only correct for one MP layer (guarded by the forzihan assert).
                if len(self.h_recurrent) == 0:  # set zero initial hidden activity
                    self.h_recurrent.append(torch.zeros_like(hidden_pre))
                h_prev = self.h_recurrent[-1]
                W_rec = self.recurrent_weights[mpl_idx]
                hidden_pre = hidden_pre + torch.matmul(h_prev, W_rec.t())
            layer_input = self.act_fn(hidden_pre)
            # save the t-1 hidden state for the use at time t
            if self.recurrent_active:
                self.h_recurrent[-1] = layer_input
            if run_mode in ('debug',):
                print(f' MP Layer {mpl_idx} forward.')
                print(' Pre-act mean {:.2e} Post-act mean {:.2e}'.format(
                    torch.mean(hidden_pre.detach()), torch.mean(layer_input.detach())
                ))
            mpl_activities.append(layer_input)  # Postsyn activity
            if run_mode in ('track_states',):
                db['hidden_pre{}'.format(mp_layer.mp_layer_name)] = hidden_pre.detach()
                db['hidden{}'.format(mp_layer.mp_layer_name)] = layer_input.detach()
                db['M{}'.format(mp_layer.mp_layer_name)] = db_mp['M']
                db['b{}'.format(mp_layer.mp_layer_name)] = db_mp['b']
                db['input{}'.format(mp_layer.mp_layer_name)] = layer_input_old.detach()
        if run_mode in ('debug',):
            print(' Output layer forward.')
        output_hidden = torch.einsum('iI, BI -> Bi', self.W_output, layer_input)
        output = output_hidden + self.b_output.unsqueeze(0)
        return output, mpl_activities, db

    def network_step(self, current_input, run_mode='minimal', verbose=False, seq_idx=None):
        """
        Performs a single batch pass forward for the network. This mostly consists of a forward pass and
        the associated updates to internal states (i.e. the modulations)
        This should not be passed a full sequence of data, only data from a given time point
        """
        assert len(current_input.shape) == 2
        if run_mode in ('debug',):
            print(' Network step:')
        # at beginning of sequence (seq_idx == 0), clear the recurrent hidden state
        if seq_idx == 0 and self.recurrent_active:
            self.h_recurrent = []
        # current_input is per-time input, so has shape (batch_size, input_size)
        output, mpl_activities, db = self.forward(current_input, run_mode=run_mode, verbose=verbose)
        # M updated internally when this is called
        for mpl_idx, mp_layer in enumerate(self.mp_layers):
            if run_mode in ('debug',):
                print(f' MP Layer {mpl_idx} M update.')
                print(' Pre mean {:.2e} post mean {:.2e}'.format(
                    torch.mean(mpl_activities[mpl_idx].detach()), torch.mean(mpl_activities[mpl_idx + 1].detach())
                ))
                print(' M mag mean {:.2e} max {:.2e}'.format(
                    torch.mean(torch.abs(mp_layer.M.detach())), torch.max(torch.abs(mp_layer.M.detach()))
                ))
            _ = mp_layer.update_M_matrix(mpl_activities[mpl_idx], mpl_activities[mpl_idx + 1])
        return output, mpl_activities, db