Skip to content

Commit f30b840

Browse files
committed
cleaned up lif node class and updates/tweaks
1 parent a47e69f commit f30b840

File tree

7 files changed

+176
-182
lines changed

7 files changed

+176
-182
lines changed

docs/walkthroughs/demo7_spiking.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ tau_mem = 20 # membrane potential time constant
320320
V_thr = 0.4 # spiking threshold
321321
# Default for rec_T of 1 ms will be used - this is the default for SpNode_LIF(s)
322322
integrate_cfg = {"integrate_type" : "euler", "dt" : dt}
323-
spike_kernel = {"V_thr" : V_thr, "tau_mem" : tau_mem}
323+
spike_kernel = {"V_thr" : V_thr, "tau_m" : tau_mem}
324324
trace_kernel = {"dt" : dt, "tau" : 5.0}
325325

326326
# set up system -- notice for z2, a gain of 0.25 yields spike frequency of 63.75 Hz
@@ -354,7 +354,7 @@ this sets up an SNN structure with three layers -- an input layer `z2` containin
354354
the Poisson spike train nodes (which will be driven by input data `x`), an internal
355355
layer of LIF nodes, and an output layer of LIF nodes. We have also opted to
356356
simplify the choice of meta-parameters and directly set the membrane potential
357-
constant `tau_mem` directly (instead of messing with membrane resistance and capacitance).
357+
constant `tau_m` directly (instead of messing with membrane resistance and capacitance).
358358
Nothing else is out of the ordinary in creating an `NGCGraph` except that we have
359359
also included a simple specialized convenience node `d1`, which will serve as a special part
360360
of our SNN structure that will naturally give us an easy way to adapt this SNN's

ngclearn/engine/ngc_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def evolve(self, clamped_vars=None, readout_vars=None, init_vars=None,
609609
)
610610
return readouts
611611

612-
def clear(self):
612+
def clear(self, batch_size=-1):
613613
"""
614614
Clears/deletes any persistent signals currently embedded w/in this graph's Nodes
615615
"""
@@ -618,4 +618,4 @@ def clear(self):
618618
self.injection_table = {}
619619
for node_name in self.nodes:
620620
node = self.nodes.get(node_name)
621-
node.clear()
621+
node.clear(batch_size=batch_size)

ngclearn/engine/nodes/node.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def calc_update(self, update_radius=-1.0):
257257
"""
258258
return []
259259

260-
def clear(self):
260+
def clear(self, batch_size=-1):
261261
""" Wipes/clears values of each compartment in this node (and sets .is_clamped = False). """
262262
#print("CLEAR for {} w/ ip = {}".format(self.name, self.do_inplace))
263263
#tf.print("=============== CLEAR ===============")
@@ -269,14 +269,20 @@ def clear(self):
269269
if self.do_inplace == True:
270270
self.compartments[comp_name].assign(comp_value * 0)
271271
else:
272-
self.compartments[comp_name] = (comp_value * 0)
272+
if batch_size > 0:
273+
self.compartments[comp_name] = tf.zeros([batch_size, comp_value.shape[1]])
274+
else:
275+
self.compartments[comp_name] = (comp_value * 0)
273276
for mask_name in self.mask_names:
274277
mask_value = self.masks.get(mask_name)
275278
if mask_value is not None:
276279
if self.do_inplace == True:
277280
self.masks[mask_name].assign(mask_value * 0 + 1)
278281
else:
279-
self.masks[mask_name] = (mask_value * 0 + 1)
282+
if batch_size > 0:
283+
self.masks[mask_name] = tf.ones([batch_size, mask_value.shape[1]])
284+
else:
285+
self.masks[mask_name] = (mask_value * 0 + 1)
280286

281287
def set_cold_state(self, injection_table=None, batch_size=-1):
282288
"""

ngclearn/engine/nodes/spnode_enc.py

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class SpNode_Enc(Node):
2222
| * Sz - the current spike values (binary vector signal) at time t
2323
| * Trace_z - filtered trace values of the spike values (real-valued vector)
2424
| * mask - a binary mask to be applied to the neural activities
25+
| * t_spike - time of last spike (per neuron inside this node)
2526
2627
Args:
2728
name: the name/label of this node
@@ -32,16 +33,26 @@ class SpNode_Enc(Node):
3233
3334
batch_size: batch-size this node should assume (for use with static graph optimization)
3435
36+
integrate_kernel: Dict defining the neural state integration process type. The expected keys and
37+
corresponding value types are specified below:
38+
39+
:`'integrate_type'`: <UNUSED>
40+
41+
:`'dt'`: type integration time constant for the spiking neurons
42+
43+
spike_kernel: <UNUSED>
44+
3545
trace_kernel: Dict defining the signal tracing process type. The expected keys and
3646
corresponding value types are specified below:
3747
3848
:`'dt'`: type integration time constant for the trace
3949
40-
:`'tau'`: the filter time constant for the trace
50+
:`'tau_trace'`: the filter time constant for the trace
4151
4252
:Note: specifying None will automatically set this node to not use variable tracing
4353
"""
44-
def __init__(self, name, dim, gain=1.0, batch_size=1, trace_kernel=None):
54+
def __init__(self, name, dim, gain=1.0, batch_size=1, integrate_cfg=None,
55+
spike_kernel=None, trace_kernel=None):
4556
node_type = "spike_enc_state"
4657
super().__init__(node_type, name, dim)
4758
self.dim = dim
@@ -50,28 +61,29 @@ def __init__(self, name, dim, gain=1.0, batch_size=1, trace_kernel=None):
5061

5162
self.gain = gain
5263
self.dt = 1.0 # integration time constant (ms)
64+
self.tau_m = 1.0
65+
self.integrate_cfg = integrate_cfg
66+
if self.integrate_cfg is not None:
67+
self.dt = self.integrate_cfg.get("dt")
5368

5469
self.trace_kernel = trace_kernel
5570
self.trace_dt = 1.0
71+
self.tau_trace = 20.0
5672
if self.trace_kernel is not None:
57-
# trace integration time constant (ms)
58-
self.trace_dt = self.trace_kernel.get("dt")
59-
# filter time constant
60-
self.tau = self.trace_kernel.get("tau")
61-
62-
# derived settings that are a function of other spiking neuron settings
63-
self.a = np.exp(-self.trace_dt/self.tau)
64-
self.tau_j = 1.0
73+
if self.trace_kernel.get("dt") is not None:
74+
self.trace_dt = self.trace_kernel.get("dt") # trace integration time constant (ms)
75+
#5.0 # filter time constant -- where dt (or T) = 0.001 (to model ms)
76+
self.tau_trace = self.trace_kernel.get("tau_trace")
6577

6678
# set LIF spiking neuron-specific (vector/scalar) constants
67-
self.constant_name = ["gain", "dt", "trace_alpha"]
79+
self.constant_name = ["gain", "dt", "tau_trace"]
6880
self.constants = {}
6981
self.constants["dt"] = self.dt
7082
self.constants["gain"] = self.gain
71-
self.constants["trace_alpha"] = self.a
83+
self.constants["tau_trace"] = self.tau_trace
7284

7385
# set LIF spiking neuron-specific vector statistics
74-
self.compartment_names = ["z", "Sz", "Trace_z"] #, "x_tar", "Ns"]
86+
self.compartment_names = ["z", "Sz", "Trace_z", "t_spike", "ref"]
7587
self.compartments = {}
7688
for name in self.compartment_names:
7789
self.compartments[name] = tf.Variable(tf.zeros([batch_size,dim]),
@@ -82,8 +94,6 @@ def __init__(self, name, dim, gain=1.0, batch_size=1, trace_kernel=None):
8294
self.masks[name] = tf.Variable(tf.ones([batch_size,dim]),
8395
name="{}_{}".format(self.name, name))
8496

85-
self.connected_cables = []
86-
8797
def compile(self):
8898
info = super().compile()
8999
#info["leak"] = self.leak
@@ -98,41 +108,48 @@ def step(self, injection_table=None, skip_core_calc=False):
98108
bmask = self.masks.get("mask")
99109
########################################################################
100110
if skip_core_calc == False:
111+
# compute spike response model
112+
dt = self.constants.get("dt") # integration time constant
101113
z = self.compartments.get("z")
114+
self.t = self.t + dt # advance time forward by dt (t <- t + dt)
102115
#Sz = transform.convert_to_spikes(z, self.max_spike_rate, self.dt)
103116
Sz = stat.convert_to_spikes(z, gain=self.gain)
104117

105118
if injection_table.get("Sz") is None:
106119
if self.do_inplace == True:
120+
t_spike = self.compartments.get("t_spike")
121+
t_spike = t_spike * (1.0 - Sz) + (Sz * self.t)
107122
self.compartments["Sz"].assign(Sz)
123+
self.compartments["t_spike"].assign(t_spike)
108124
else:
125+
t_spike = self.compartments.get("t_spike")
126+
t_spike = t_spike * (1.0 - Sz) + (Sz * self.t)
109127
self.compartments["Sz"] = Sz
110-
##########################################################################
111-
112-
##########################################################################
113-
trace_alpha = self.constants.get("trace_alpha")
114-
trace_z_tm1 = self.compartments.get("Trace_z")
115-
# apply variable trace filters z_l(t) = (alpha * z_l(t))*(1−s`(t)) +s_l(t)
116-
trace_z = tf.add((trace_z_tm1 * trace_alpha) * (-Sz + 1.0), Sz)
128+
self.compartments["t_spike"] = t_spike
129+
130+
#### update trace variable ####
131+
tau_tr = self.constants.get("tau_trace")
132+
#Sz = self.compartments.get("Sz")
133+
tr_z = self.compartments.get("Trace_z")
134+
d_tr = -tr_z/tau_tr + Sz
135+
tr_z = tr_z + d_tr
117136
if injection_table.get("Trace_z") is None:
118137
if self.do_inplace == True:
119-
self.compartments["Trace_z"].assign(trace_z)
138+
self.compartments["Trace_z"].assign(tr_z)
120139
else:
121-
self.compartments["Trace_z"] = trace_z
122-
# Ns = self.compartments.get("Ns")
123-
# x_tar = self.compartments.get("x_tar")
124-
# x_tar = x_tar + (trace_z - x_tar)/Ns
125-
# if injection_table.get("x_tar") is None:
126-
# if self.do_inplace == True:
127-
# self.compartments["x_tar"].assign(x_tar)
128-
# else:
129-
# self.compartments["x_tar"] = x_tar
130-
# Ns = Ns + 1
131-
# if injection_table.get("Ns") is None:
140+
self.compartments["Trace_z"] = tr_z
141+
##########################################################################
142+
143+
##########################################################################
144+
# trace_alpha = self.constants.get("trace_alpha")
145+
# trace_z_tm1 = self.compartments.get("Trace_z")
146+
# # apply variable trace filters z_l(t) = (alpha * z_l(t))*(1−s`(t)) +s_l(t)
147+
# trace_z = tf.add((trace_z_tm1 * trace_alpha) * (-Sz + 1.0), Sz)
148+
# if injection_table.get("Trace_z") is None:
132149
# if self.do_inplace == True:
133-
# self.compartments["Ns"].assign(Ns)
150+
# self.compartments["Trace_z"].assign(trace_z)
134151
# else:
135-
# self.compartments["Ns"] = Ns
152+
# self.compartments["Trace_z"] = trace_z
136153

137154
if bmask is not None: # applies mask to all component variables of this node
138155
for key in self.compartments:
@@ -143,8 +160,6 @@ def step(self, injection_table=None, skip_core_calc=False):
143160
self.compartments[key] = ( self.compartments.get(key) * bmask )
144161

145162
########################################################################
146-
if skip_core_calc == False:
147-
self.t = self.t + 1
148163

149164
# a node returns a list of its named component values
150165
values = []

0 commit comments

Comments
 (0)