-
Notifications
You must be signed in to change notification settings - Fork 1
/
learning.py
136 lines (117 loc) · 5.19 KB
/
learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from typing import Union, Optional, Sequence
import torch
from bindsnet.network.topology import (
AbstractConnection,
Connection,
LocalConnection,
)
from locally_connected_multi_chan import LocalConnection2D
from bindsnet.learning import LearningRule
class PostPre(LearningRule):
    """
    Simple STDP rule involving both pre- and post-synaptic spiking activity. By
    default, the pre-synaptic update is negative and the post-synaptic update is
    positive. This BindsNET class is modified to also support the multi-channel
    local connection (``LocalConnection2D``).
    """

    def __init__(
        self,
        connection: AbstractConnection,
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        **kwargs,
    ) -> None:
        """
        Constructor for ``PostPre`` learning rule.

        :param connection: An ``AbstractConnection`` object whose weights the
            ``PostPre`` learning rule will modify.
        :param nu: Single or pair of learning rates for pre- and post-synaptic
            events.
        :param reduction: Method for reducing parameter updates along the batch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each
            iteration.
        :raises AssertionError: If the source or target nodes do not record
            spike traces.
        :raises NotImplementedError: If ``connection`` is not a ``Connection``,
            ``LocalConnection`` or ``LocalConnection2D``.
        """
        # NOTE(review): ``self.source``, ``self.target``, ``self.connection``,
        # ``self.nu`` and ``self.reduction`` are presumably set up by
        # ``LearningRule.__init__`` -- it is not visible from this file.
        super().__init__(
            connection=connection,
            nu=nu,
            reduction=reduction,
            weight_decay=weight_decay,
            **kwargs,
        )

        assert (
            self.source.traces and self.target.traces
        ), "Both pre- and post-synaptic nodes must record spike traces."

        # Bind the update routine matching the connection type. The order of
        # the checks is kept as-is: the first matching branch wins.
        if isinstance(connection, (Connection, LocalConnection)):
            self.update = self._connection_update
        elif isinstance(connection, LocalConnection2D):
            self.update = self._local_connection_update
        else:
            raise NotImplementedError(
                "This learning rule is not supported for this Connection type."
            )

    def _connection_update(self, **kwargs) -> None:
        """
        Post-pre learning rule for ``Connection`` subclass of ``AbstractConnection``
        class.

        Source/target spikes and traces are flattened per batch element, combined
        as batched outer products, reduced over the batch dimension with
        ``self.reduction`` and applied in place to ``self.connection.w``.
        """
        batch_size = self.source.batch_size
        device = self.connection.w.device

        # Pre-synaptic update.
        if self.nu[0]:
            source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float().to(device)
            target_x = (
                self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
            ).to(device)
            self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
            # Drop the temporaries before building the post-synaptic ones.
            del source_s, target_x

        # Post-synaptic update.
        if self.nu[1]:
            target_s = (
                self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
            ).to(device)
            source_x = self.source.x.view(batch_size, -1).unsqueeze(2).to(device)
            self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
            del source_x, target_s

        # Defer to the base-class update (not visible from this file).
        super().update()

    def _local_connection_update(self, **kwargs) -> None:
        """
        Post-pre learning rule for ``LocalConnection2D`` subclass of
        ``AbstractConnection`` class.

        For every output neuron, the unfolded input patch it looks at is scaled
        by that neuron's trace (pre-synaptic term) or spike (post-synaptic term),
        reduced over the batch dimension with ``self.reduction`` and applied in
        place to ``self.connection.w``.
        """
        # Get LC layer parameters.
        stride = self.connection.stride
        batch_size = self.source.batch_size
        kernel_width = self.connection.kernel_size[0]
        kernel_height = self.connection.kernel_size[1]
        in_channels = self.connection.in_channels
        out_channels = self.connection.out_channels
        width_out = self.connection.conv_size[0]
        height_out = self.connection.conv_size[1]

        n_locations = width_out * height_out
        n_targets = out_channels * n_locations
        patch_size = in_channels * kernel_width * kernel_height
        device = self.connection.w.device

        def _patches(source: torch.Tensor) -> torch.Tensor:
            # Slide the kernel over the last two dims, flatten to one row per
            # output location, then tile the rows once per output channel so
            # they line up with the (ch_o, w_o, h_o) flattening of the target.
            # NOTE(review): assumes ``source`` is (batch, in_channels, W, H).
            # The reshape flattens without permuting the channel axis next to
            # the kernel axes; for in_channels > 1 confirm this matches the
            # layout of ``connection.w``.
            return (
                source.unfold(-2, kernel_width, stride[0])
                .unfold(-2, kernel_height, stride[1])
                .reshape(batch_size, n_locations, patch_size)
                .repeat(1, out_channels, 1)
                .to(device)
            )

        # Scaling row i of the patches by target value i is what the dense
        # ``diag(target) @ patches`` product computes; broadcasting a
        # (batch, n_targets, 1) column does the same without materialising an
        # (n_targets x n_targets) identity matrix.

        # Pre-synaptic update.
        if self.nu[0]:
            # target_x: ch_o, w_o, h_o
            target_x = self.target.x.reshape(batch_size, n_targets, 1).to(device)
            source_s = _patches(self.source.s.type(torch.float))
            pre = self.reduction(target_x * source_s, dim=0)
            self.connection.w -= self.nu[0] * pre.view(self.connection.w.size())

        # Post-synaptic update.
        if self.nu[1]:
            # target_s: ch_o, w_o, h_o
            target_s = (
                self.target.s.type(torch.float)
                .reshape(batch_size, n_targets, 1)
                .to(device)
            )
            source_x = _patches(self.source.x)
            post = self.reduction(target_s * source_x, dim=0)
            self.connection.w += self.nu[1] * post.view(self.connection.w.size())

        # Defer to the base-class update (not visible from this file).
        super().update()