Skip to content

Commit 92b6940

Browse files
author
Alexander Ororbia
committed
ported over IF/quadLIF cells, minor revision to LIF cell
1 parent 2fc2300 commit 92b6940

File tree

3 files changed

+49
-49
lines changed

3 files changed

+49
-49
lines changed

ngclearn/components/neurons/spiking/IFCell.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22
from jax import numpy as jnp, random, jit, nn
33
from functools import partial
44
from ngclearn.utils import tensorstats
5-
from ngcsimlib.deprecators import deprecate_args
5+
from ngcsimlib import deprecate_args
66
from ngcsimlib.logger import info, warn
77
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
88
step_euler, step_rk2
9-
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10-
triangular_estimator,
11-
straight_through_estimator)
9+
# from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10+
# triangular_estimator,
11+
# straight_through_estimator)
1212

13-
from ngcsimlib.compilers.process import transition
14-
#from ngcsimlib.component import Component
13+
from ngcsimlib.parser import compilable
1514
from ngcsimlib.compartment import Compartment
1615

1716

@@ -35,7 +34,7 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
3534
The specific differential equation that characterizes this cell
3635
is (for adjusting v, given current j, over time) is:
3736
38-
| tau_m * dv/dt = (v_rest - v) + j * R
37+
| tau_m * dv/dt = j * R
3938
| where R is the membrane resistance and v_rest is the resting potential
4039
| also, if a spike occurs, v is set to v_reset
4140
@@ -91,10 +90,10 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
9190
"""
9291

9392
@deprecate_args(thr_jitter=None)
94-
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
95-
v_reset=-60., refract_time=0., integration_type="euler",
96-
surrogate_type="straight_through", lower_clamp_voltage=True,
97-
**kwargs):
93+
def __init__(
94+
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., refract_time=0.,
95+
integration_type="euler", surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs
96+
):
9897
super().__init__(name, **kwargs)
9998

10099
## Integration properties
@@ -118,12 +117,12 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
118117
self.n_units = n_units
119118

120119
## set up surrogate function for spike emission
121-
if surrogate_type == "arctan":
122-
self.spike_fx, self.d_spike_fx = arctan_estimator()
123-
elif surrogate_type == "triangular":
124-
self.spike_fx, self.d_spike_fx = triangular_estimator()
125-
else: ## default: straight_through
126-
self.spike_fx, self.d_spike_fx = straight_through_estimator()
120+
# if surrogate_type == "arctan":
121+
# self.spike_fx, self.d_spike_fx = arctan_estimator()
122+
# elif surrogate_type == "triangular":
123+
# self.spike_fx, self.d_spike_fx = triangular_estimator()
124+
# else: ## default: straight_through
125+
# self.spike_fx, self.d_spike_fx = straight_through_estimator()
127126

128127

129128
## Compartment setup
@@ -138,47 +137,48 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
138137
units="ms") ## time-of-last-spike
139138
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
140139

141-
@transition(output_compartments=["v", "s", "rfr", "tols", "key", "surrogate"])
142-
@staticmethod
140+
@compilable
143141
def advance_state(
144-
t, dt, tau_m, resist_m, v_rest, v_reset, refract_T, thr, lower_clamp_voltage, intgFlag, d_spike_fx, key,
145-
j, v, rfr, tols
142+
self, dt, t
146143
):
147144
## run one integration step for neuronal dynamics
148-
j = j * resist_m
145+
j = self.j.get() * self.resist_m
149146

150147
### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
151148
## update voltage / membrane potential
152-
v_params = (j, rfr, tau_m, refract_T)
153-
if intgFlag == 1:
154-
_, _v = step_rk2(0., v, _dfv, dt, v_params)
149+
v_params = (j, self.rfr.get(), self.tau_m, self.refract_T)
150+
if self.intgFlag == 1:
151+
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
155152
else:
156-
_, _v = step_euler(0., v, _dfv, dt, v_params)
153+
_, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
157154
## obtain action potentials/spikes
158-
s = (_v > thr) * 1.
155+
s = (_v > self.thr) * 1.
159156
## update refractory variables
160-
rfr = (rfr + dt) * (1. - s)
157+
rfr = (self.rfr.get() + dt) * (1. - s)
161158
## perform hyper-polarization of neuronal cells
162-
v = _v * (1. - s) + s * v_reset
159+
v = _v * (1. - s) + s * self.v_reset
160+
161+
#surrogate = d_spike_fx(v, self.thr)
163162

164-
surrogate = d_spike_fx(v, thr)
165163
## update tols
166-
tols = (1. - s) * tols + (s * t)
167-
if lower_clamp_voltage: ## ensure voltage never < v_rest
168-
v = jnp.maximum(v, v_rest)
169-
return v, s, rfr, tols, key, surrogate
170-
171-
@transition(output_compartments=["j", "v", "s", "rfr", "tols", "surrogate"])
172-
@staticmethod
173-
def reset(batch_size, n_units, v_rest, refract_T):
174-
restVals = jnp.zeros((batch_size, n_units))
175-
j = restVals #+ 0
176-
v = restVals + v_rest
177-
s = restVals #+ 0
178-
rfr = restVals + refract_T
179-
tols = restVals #+ 0
180-
surrogate = restVals + 1.
181-
return j, v, s, rfr, tols, surrogate
164+
self.tols.set((1. - s) * self.tols.get() + (s * t))
165+
if self.lower_clamp_voltage: ## ensure voltage never < v_rest
166+
_v = jnp.maximum(v, self.v_rest)
167+
168+
self.v.set(_v)
169+
self.s.set(s)
170+
self.rfr.set(rfr)
171+
172+
@compilable
173+
def reset(self):
174+
restVals = jnp.zeros((self.batch_size, self.n_units))
175+
if not self.j.targeted:
176+
self.j.set(restVals)
177+
self.v.set(restVals + self.v_rest)
178+
self.s.set(restVals)
179+
self.rfr.set(restVals + self.refract_T)
180+
self.tols.set(restVals)
181+
#surrogate = restVals + 1.
182182

183183
def save(self, directory, **kwargs):
184184
## do a protected save of constants, depending on whether they are floats or arrays

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
1919

2020

2121
#@partial(jit, static_argnums=[3, 4])
22-
def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
22+
def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array=0.05):
2323
### Runs homeostatic threshold update dynamics one step (via Euler integration).
2424
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
2525
#theta_plus = 0.05

ngclearn/components/neurons/spiking/quadLIFCell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
2929
return dv_dt
3030

3131
#@partial(jit, static_argnums=[3, 4])
32-
def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array | float=0.05):
32+
def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array=0.05):
3333
### Runs homeostatic threshold update dynamics one step (via Euler integration).
3434
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
3535
#theta_plus = 0.05
@@ -138,7 +138,7 @@ def advance_state(self, dt, t):
138138

139139
_v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold
140140

141-
v_params = (j, self.rfr.get(), self.tau_m.get(), self.refract_T, self.v_rest, self.v_c, self.a0)
141+
v_params = (j, self.rfr.get(), self.tau_m, self.refract_T, self.v_rest, self.v_c, self.a0)
142142

143143
if self.intgFlag == 1:
144144
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)

0 commit comments

Comments
 (0)