@@ -22,6 +22,7 @@ class SpNode_Enc(Node):
     | * Sz - the current spike values (binary vector signal) at time t
     | * Trace_z - filtered trace values of the spike values (real-valued vector)
     | * mask - a binary mask to be applied to the neural activities
+    | * t_spike - time of last spike (per neuron inside this node)
 
     Args:
         name: the name/label of this node
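For reference, each compartment listed above is held as a zero-initialized `[batch_size, dim]` tensor (see the constructor hunks further down). A minimal standalone sketch of that layout, using hypothetical sizes and plain TensorFlow rather than the node class itself:

```python
import tensorflow as tf

# Hypothetical sizes for illustration only (not taken from this diff).
batch_size, dim = 1, 10

# Mirrors how the constructor below builds its compartment table:
# one zero-initialized [batch_size, dim] tf.Variable per named compartment.
compartment_names = ["z", "Sz", "Trace_z", "t_spike"]
compartments = {name: tf.Variable(tf.zeros([batch_size, dim]), name=name)
                for name in compartment_names}

print({k: v.shape for k, v in compartments.items()})
```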
@@ -32,16 +33,26 @@ class SpNode_Enc(Node):
 
         batch_size: batch-size this node should assume (for use with static graph optimization)
 
+        integrate_kernel: Dict defining the neural state integration process type. The expected keys and
+            corresponding value types are specified below:
+
+            :`'integrate_type'`: <UNUSED>
+
+            :`'dt'`: type integration time constant for the spiking neurons
+
+        spike_kernel: <UNUSED>
+
         trace_kernel: Dict defining the signal tracing process type. The expected keys and
             corresponding value types are specified below:
 
             :`'dt'`: type integration time constant for the trace
 
-            :`'tau'`: the filter time constant for the trace
+            :`'tau_trace'`: the filter time constant for the trace
 
         :Note: specifying None will automatically set this node to not use variable tracing
     """
-    def __init__(self, name, dim, gain=1.0, batch_size=1, trace_kernel=None):
+    def __init__(self, name, dim, gain=1.0, batch_size=1, integrate_cfg=None,
+                 spike_kernel=None, trace_kernel=None):
         node_type = "spike_enc_state"
         super().__init__(node_type, name, dim)
         self.dim = dim
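Given the new `__init__` signature above, constructing the encoder might look like the sketch below. Note that the docstring documents the integration dictionary as `integrate_kernel`, while the signature names the keyword `integrate_cfg`; the import path and all concrete values here are assumptions for illustration, not part of this diff.

```python
# Hypothetical import path -- adjust to wherever SpNode_Enc lives in your checkout.
from ngclearn.engine.nodes.spnode_enc import SpNode_Enc

integrate_cfg = {"integrate_type": "euler", "dt": 0.25}  # 'integrate_type' is marked <UNUSED> above
trace_cfg = {"dt": 0.25, "tau_trace": 20.0}              # keys named in the docstring

enc = SpNode_Enc(name="z0_enc", dim=10, gain=0.8, batch_size=1,
                 integrate_cfg=integrate_cfg, spike_kernel=None,
                 trace_kernel=trace_cfg)
```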
@@ -50,28 +61,29 @@ def __init__(self, name, dim, gain=1.0, batch_size=1, trace_kernel=None):
 
         self.gain = gain
         self.dt = 1.0 # integration time constant (ms)
+        self.tau_m = 1.0
+        self.integrate_cfg = integrate_cfg
+        if self.integrate_cfg is not None:
+            self.dt = self.integrate_cfg.get("dt")
 
         self.trace_kernel = trace_kernel
         self.trace_dt = 1.0
+        self.tau_trace = 20.0
         if self.trace_kernel is not None:
-            # trace integration time constant (ms)
-            self.trace_dt = self.trace_kernel.get("dt")
-            # filter time constant
-            self.tau = self.trace_kernel.get("tau")
-
-            # derived settings that are a function of other spiking neuron settings
-            self.a = np.exp(-self.trace_dt / self.tau)
-            self.tau_j = 1.0
+            if self.trace_kernel.get("dt") is not None:
+                self.trace_dt = self.trace_kernel.get("dt") # trace integration time constant (ms)
+            #5.0 # filter time constant -- where dt (or T) = 0.001 (to model ms)
+            self.tau_trace = self.trace_kernel.get("tau_trace")
 
         # set LIF spiking neuron-specific (vector/scalar) constants
-        self.constant_name = ["gain", "dt", "trace_alpha"]
+        self.constant_name = ["gain", "dt", "tau_trace"]
         self.constants = {}
         self.constants["dt"] = self.dt
         self.constants["gain"] = self.gain
-        self.constants["trace_alpha"] = self.a
+        self.constants["tau_trace"] = self.tau_trace
 
         # set LIF spiking neuron-specific vector statistics
-        self.compartment_names = ["z", "Sz", "Trace_z"] # , "x_tar", "Ns"]
+        self.compartment_names = ["z", "Sz", "Trace_z", "t_spike", "ref"]
         self.compartments = {}
         for name in self.compartment_names:
             self.compartments[name] = tf.Variable(tf.zeros([batch_size,dim]),
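The constructor change above also swaps how the trace filter is parameterized: the removed lines precomputed a discrete decay factor `trace_alpha = exp(-trace_dt / tau)`, while the new code stores `tau_trace` directly and lets `step()` (further down) apply `tr_z = tr_z - tr_z / tau_trace + Sz`, a forward-Euler decay with an implicit unit step. A small numpy check, with assumed values, showing the two decay factors are close when the step is small relative to the time constant:

```python
import numpy as np

# Assumed illustrative values (not taken from this diff).
trace_dt = 1.0    # trace integration time constant (ms)
tau_trace = 20.0  # filter time constant (ms)

# Old parameterization: exact discrete decay factor applied per step.
alpha_exact = np.exp(-trace_dt / tau_trace)

# New parameterization: one forward-Euler decay step of d_tr = -tr/tau_trace
# (the spike term is omitted here to isolate the decay).
alpha_euler = 1.0 - trace_dt / tau_trace

print(alpha_exact, alpha_euler)  # ~0.9512 vs 0.95
```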
@@ -82,8 +94,6 @@ def __init__(self, name, dim, gain=1.0, batch_size=1, trace_kernel=None):
             self.masks[name] = tf.Variable(tf.ones([batch_size,dim]),
                                            name="{}_{}".format(self.name, name))
 
-        self.connected_cables = []
-
     def compile(self):
         info = super().compile()
         #info["leak"] = self.leak
@@ -98,41 +108,48 @@ def step(self, injection_table=None, skip_core_calc=False):
         bmask = self.masks.get("mask")
         ########################################################################
         if skip_core_calc == False:
+            # compute spike response model
+            dt = self.constants.get("dt") # integration time constant
             z = self.compartments.get("z")
+            self.t = self.t + dt # advance time forward by dt (t <- t + dt)
             #Sz = transform.convert_to_spikes(z, self.max_spike_rate, self.dt)
             Sz = stat.convert_to_spikes(z, gain=self.gain)
 
             if injection_table.get("Sz") is None:
                 if self.do_inplace == True:
+                    t_spike = self.compartments.get("t_spike")
+                    t_spike = t_spike * (1.0 - Sz) + (Sz * self.t)
                     self.compartments["Sz"].assign(Sz)
+                    self.compartments["t_spike"].assign(t_spike)
                 else:
+                    t_spike = self.compartments.get("t_spike")
+                    t_spike = t_spike * (1.0 - Sz) + (Sz * self.t)
                     self.compartments["Sz"] = Sz
-            ##########################################################################
-
-            ##########################################################################
-            trace_alpha = self.constants.get("trace_alpha")
-            trace_z_tm1 = self.compartments.get("Trace_z")
-            # apply variable trace filters z_l(t) = (alpha * z_l(t)) * (1 - s_l(t)) + s_l(t)
-            trace_z = tf.add((trace_z_tm1 * trace_alpha) * (-Sz + 1.0), Sz)
+                    self.compartments["t_spike"] = t_spike
+
+            #### update trace variable ####
+            tau_tr = self.constants.get("tau_trace")
+            #Sz = self.compartments.get("Sz")
+            tr_z = self.compartments.get("Trace_z")
+            d_tr = -tr_z / tau_tr + Sz
+            tr_z = tr_z + d_tr
             if injection_table.get("Trace_z") is None:
                 if self.do_inplace == True:
-                    self.compartments["Trace_z"].assign(trace_z)
+                    self.compartments["Trace_z"].assign(tr_z)
                 else:
-                    self.compartments["Trace_z"] = trace_z
-            # Ns = self.compartments.get("Ns")
-            # x_tar = self.compartments.get("x_tar")
-            # x_tar = x_tar + (trace_z - x_tar)/Ns
-            # if injection_table.get("x_tar") is None:
-            #     if self.do_inplace == True:
-            #         self.compartments["x_tar"].assign(x_tar)
-            #     else:
-            #         self.compartments["x_tar"] = x_tar
-            # Ns = Ns + 1
-            # if injection_table.get("Ns") is None:
+                    self.compartments["Trace_z"] = tr_z
+            ##########################################################################
+
+            ##########################################################################
+            # trace_alpha = self.constants.get("trace_alpha")
+            # trace_z_tm1 = self.compartments.get("Trace_z")
+            # # apply variable trace filters z_l(t) = (alpha * z_l(t)) * (1 - s_l(t)) + s_l(t)
+            # trace_z = tf.add((trace_z_tm1 * trace_alpha) * (-Sz + 1.0), Sz)
+            # if injection_table.get("Trace_z") is None:
             #     if self.do_inplace == True:
-            #         self.compartments["Ns"].assign(Ns)
+            #         self.compartments["Trace_z"].assign(trace_z)
             #     else:
-            #         self.compartments["Ns"] = Ns
+            #         self.compartments["Trace_z"] = trace_z
 
         if bmask is not None: # applies mask to all component variables of this node
             for key in self.compartments:
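The new core of `step()` above does three things per call: advances the node clock by `dt`, stamps `t_spike` with the current time wherever a spike occurred (`t_spike * (1 - Sz) + Sz * t`), and relaxes the trace toward zero while bumping it by the spike vector. A self-contained numpy sketch of just those updates, driven by a made-up Bernoulli spike train rather than the node's own `convert_to_spikes` call; the step size, time constant, and spike probability are assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

dim = 5
dt = 0.25          # assumed integration step (ms)
tau_trace = 20.0   # assumed trace time constant (ms)

t = 0.0
t_spike = np.zeros(dim)   # time of last spike, per neuron
tr_z = np.zeros(dim)      # spike trace, per neuron

for _ in range(100):
    t = t + dt                               # advance the node clock by dt
    Sz = (rng.random(dim) < 0.05) * 1.0      # stand-in for convert_to_spikes(...)
    t_spike = t_spike * (1.0 - Sz) + Sz * t  # record time of most recent spike
    tr_z = tr_z + (-tr_z / tau_trace + Sz)   # forward-Euler trace update (unit step)

print(t_spike)
print(tr_z)
```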
@@ -143,8 +160,6 @@ def step(self, injection_table=None, skip_core_calc=False):
                     self.compartments[key] = ( self.compartments.get(key) * bmask )
 
         ########################################################################
-        if skip_core_calc == False:
-            self.t = self.t + 1
 
         # a node returns a list of its named component values
         values = []