Skip to content

Commit 2fc2300

Browse files
author
Alexander Ororbia
committed
ported over quad-lif to v3 - needs testing
1 parent d2d4331 commit 2fc2300

File tree

2 files changed

+63
-66
lines changed

2 files changed

+63
-66
lines changed

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def advance_state(self, dt, t):
202202

203203
if self.tau_theta > 0.:
204204
## run one integration step for threshold dynamics
205-
thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus.get())
205+
thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) #.get())
206206
self.thr_theta.set(thr_theta)
207207

208208
## update tols

ngclearn/components/neurons/spiking/quadLIFCell.py

Lines changed: 62 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from ngclearn.components.jaxComponent import JaxComponent
2-
from jax import numpy as jnp, random, jit, nn
2+
from jax import numpy as jnp, random, jit, nn, Array
33
from functools import partial
44
from ngclearn.utils import tensorstats
5-
from ngcsimlib.deprecators import deprecate_args
5+
from ngcsimlib import deprecate_args
66
from ngcsimlib.logger import info, warn
77
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
88
step_euler, step_rk2
9-
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10-
triangular_estimator,
11-
straight_through_estimator)
9+
# from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10+
# triangular_estimator,
11+
# straight_through_estimator)
1212

13-
from ngcsimlib.compilers.process import transition
14-
#from ngcsimlib.component import Component
13+
from ngcsimlib.parser import compilable
1514
from ngcsimlib.compartment import Compartment
1615

1716
from ngclearn.components.neurons.spiking.LIFCell import LIFCell
@@ -30,7 +29,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
3029
return dv_dt
3130

3231
#@partial(jit, static_argnums=[3, 4])
33-
def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
32+
def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array | float=0.05):
3433
### Runs homeostatic threshold update dynamics one step (via Euler integration).
3534
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
3635
#theta_plus = 0.05
@@ -133,71 +132,69 @@ def __init__(
133132
self.v_c = v_scale
134133
self.a0 = critical_v
135134

136-
@transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"])
137-
@staticmethod
138-
def advance_state(
139-
t, dt, tau_m, resist_m, v_rest, v_reset, v_c, a0, refract_T, thr, tau_theta, theta_plus,
140-
one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
141-
):
142-
skey = None ## this is an empty dkey if single_spike mode turned off
143-
if one_spike:
144-
key, skey = random.split(key, 2)
145-
## run one integration step for neuronal dynamics
146-
j = j * resist_m
147-
############################################################################
148-
### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
149-
_v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold
150-
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
151-
## update voltage / membrane potential
152-
v_params = (j, rfr, tau_m, refract_T, v_rest, v_c, a0)
153-
if intgFlag == 1:
154-
_, _v = step_rk2(0., v, _dfv, dt, v_params)
155-
else: #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
156-
_, _v = step_euler(0., v, _dfv, dt, v_params)
157-
## obtain action potentials/spikes
135+
@compilable
136+
def advance_state(self, dt, t):
137+
j = self.j.get() * self.resist_m
138+
139+
_v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold
140+
141+
v_params = (j, self.rfr.get(), self.tau_m.get(), self.refract_T, self.v_rest, self.v_c, self.a0)
142+
143+
if self.intgFlag == 1:
144+
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
145+
else:
146+
_, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
147+
158148
s = (_v > _v_thr) * 1.
159-
## update refractory variables
160-
_rfr = (rfr + dt) * (1. - s)
161-
## perform hyper-polarization of neuronal cells
162-
_v = _v * (1. - s) + s * v_reset
163-
164-
raw_s = s + 0 ## preserve un-altered spikes
165-
############################################################################
166-
## this is a spike post-processing step
167-
if skey is not None:
168-
m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
149+
_rfr = (self.rfr.get() + dt) * (1. - s)
150+
_v = _v * (1. - s) + s * self.v_reset
151+
152+
raw_s = s
153+
154+
#surrogate = d_spike_fx(v, _v_thr) # d_spike_fx(v, thr + thr_theta)
155+
156+
if self.one_spike and not self.max_one_spike:
157+
key, skey = random.split(self.key.get(), 2)
158+
159+
m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
169160
rS = s * random.uniform(skey, s.shape)
170161
rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1],
171162
dtype=jnp.float32)
172163
s = s * (1. - m_switch) + rS * m_switch
173-
############################################################################
174-
raw_spikes = raw_s
175-
v = _v
176-
rfr = _rfr
164+
self.key.set(key)
177165

178-
surrogate = d_spike_fx(v, _v_thr) #d_spike_fx(v, thr + thr_theta)
179-
if tau_theta > 0.:
166+
if self.max_one_spike:
167+
rS = nn.one_hot(jnp.argmax(self.v.get(), axis=1), num_classes=s.shape[1],
168+
dtype=jnp.float32) ## get max-volt spike
169+
s = s * rS ## mask out non-max volt spikes
170+
171+
if self.tau_theta > 0.:
180172
## run one integration step for threshold dynamics
181-
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
173+
thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) # .get())
174+
self.thr_theta.set(thr_theta)
175+
182176
## update tols
183-
tols = (1. - s) * tols + (s * t)
184-
if lower_clamp_voltage: ## ensure voltage never < v_rest
185-
v = jnp.maximum(v, v_rest)
186-
return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate
187-
188-
@transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"])
189-
@staticmethod
190-
def reset(batch_size, n_units, v_rest, refract_T):
191-
restVals = jnp.zeros((batch_size, n_units))
192-
j = restVals #+ 0
193-
v = restVals + v_rest
194-
s = restVals #+ 0
195-
s_raw = restVals
196-
rfr = restVals + refract_T
197-
#thr_theta = restVals ## do not reset thr_theta
198-
tols = restVals #+ 0
199-
surrogate = restVals + 1.
200-
return j, v, s, s_raw, rfr, tols, surrogate
177+
self.tols.set((1. - s) * self.tols.get() + (s * t))
178+
179+
if self.v_min is not None: ## ensures voltage never < v_rest
180+
_v = jnp.maximum(_v, self.v_min)
181+
182+
self.v.set(_v)
183+
self.s.set(s)
184+
self.s_raw.set(raw_s)
185+
self.rfr.set(_rfr)
186+
187+
@compilable
188+
def reset(self):
189+
restVals = jnp.zeros((self.batch_size, self.n_units))
190+
if not self.j.targeted:
191+
self.j.set(restVals)
192+
self.v.set(restVals + self.v_rest)
193+
self.s.set(restVals)
194+
self.s_raw.set(restVals)
195+
self.rfr.set(restVals + self.refract_T)
196+
self.tols.set(restVals)
197+
#self.surrogate.set(restVals)
201198

202199
def save(self, directory, **kwargs):
203200
## do a protected save of constants, depending on whether they are floats or arrays

0 commit comments

Comments
 (0)