1
1
from ngclearn .components .jaxComponent import JaxComponent
2
- from jax import numpy as jnp , random , jit , nn
2
+ from jax import numpy as jnp , random , jit , nn , Array
3
3
from functools import partial
4
4
from ngclearn .utils import tensorstats
5
- from ngcsimlib . deprecators import deprecate_args
5
+ from ngcsimlib import deprecate_args
6
6
from ngcsimlib .logger import info , warn
7
7
from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
8
8
step_euler , step_rk2
9
- from ngclearn .utils .surrogate_fx import (secant_lif_estimator , arctan_estimator ,
10
- triangular_estimator ,
11
- straight_through_estimator )
9
+ # from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10
+ # triangular_estimator,
11
+ # straight_through_estimator)
12
12
13
- from ngcsimlib .compilers .process import transition
14
- #from ngcsimlib.component import Component
13
+ from ngcsimlib .parser import compilable
15
14
from ngcsimlib .compartment import Compartment
16
15
17
16
from ngclearn .components .neurons .spiking .LIFCell import LIFCell
@@ -30,7 +29,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
30
29
return dv_dt
31
30
32
31
#@partial(jit, static_argnums=[3, 4])
33
- def _update_theta (dt , v_theta , s , tau_theta , theta_plus = 0.05 ):
32
+ def _update_theta (dt , v_theta , s , tau_theta , theta_plus : Array | float = 0.05 ):
34
33
### Runs homeostatic threshold update dynamics one step (via Euler integration).
35
34
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
36
35
#theta_plus = 0.05
@@ -133,71 +132,69 @@ def __init__(
133
132
self .v_c = v_scale
134
133
self .a0 = critical_v
135
134
136
- @transition (output_compartments = ["v" , "s" , "s_raw" , "rfr" , "thr_theta" , "tols" , "key" , "surrogate" ])
137
- @staticmethod
138
- def advance_state (
139
- t , dt , tau_m , resist_m , v_rest , v_reset , v_c , a0 , refract_T , thr , tau_theta , theta_plus ,
140
- one_spike , lower_clamp_voltage , intgFlag , d_spike_fx , key , j , v , rfr , thr_theta , tols
141
- ):
142
- skey = None ## this is an empty dkey if single_spike mode turned off
143
- if one_spike :
144
- key , skey = random .split (key , 2 )
145
- ## run one integration step for neuronal dynamics
146
- j = j * resist_m
147
- ############################################################################
148
- ### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
149
- _v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold
150
- #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
151
- ## update voltage / membrane potential
152
- v_params = (j , rfr , tau_m , refract_T , v_rest , v_c , a0 )
153
- if intgFlag == 1 :
154
- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
155
- else : #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
156
- _ , _v = step_euler (0. , v , _dfv , dt , v_params )
157
- ## obtain action potentials/spikes
135
+ @compilable
136
+ def advance_state (self , dt , t ):
137
+ j = self .j .get () * self .resist_m
138
+
139
+ _v_thr = self .thr_theta .get () + self .thr ## calc present voltage threshold
140
+
141
+ v_params = (j , self .rfr .get (), self .tau_m .get (), self .refract_T , self .v_rest , self .v_c , self .a0 )
142
+
143
+ if self .intgFlag == 1 :
144
+ _ , _v = step_rk2 (0. , self .v .get (), _dfv , dt , v_params )
145
+ else :
146
+ _ , _v = step_euler (0. , self .v .get (), _dfv , dt , v_params )
147
+
158
148
s = (_v > _v_thr ) * 1.
159
- ## update refractory variables
160
- _rfr = (rfr + dt ) * (1. - s )
161
- ## perform hyper-polarization of neuronal cells
162
- _v = _v * (1. - s ) + s * v_reset
163
-
164
- raw_s = s + 0 ## preserve un-altered spikes
165
- ############################################################################
166
- ## this is a spike post-processing step
167
- if skey is not None :
168
- m_switch = (jnp .sum (s ) > 0. ).astype (jnp .float32 ) ## TODO: not batch-able
149
+ _rfr = (self .rfr .get () + dt ) * (1. - s )
150
+ _v = _v * (1. - s ) + s * self .v_reset
151
+
152
+ raw_s = s
153
+
154
+ #surrogate = d_spike_fx(v, _v_thr) # d_spike_fx(v, thr + thr_theta)
155
+
156
+ if self .one_spike and not self .max_one_spike :
157
+ key , skey = random .split (self .key .get (), 2 )
158
+
159
+ m_switch = (jnp .sum (s ) > 0. ).astype (jnp .float32 ) ## TODO: not batch-able
169
160
rS = s * random .uniform (skey , s .shape )
170
161
rS = nn .one_hot (jnp .argmax (rS , axis = 1 ), num_classes = s .shape [1 ],
171
162
dtype = jnp .float32 )
172
163
s = s * (1. - m_switch ) + rS * m_switch
173
- ############################################################################
174
- raw_spikes = raw_s
175
- v = _v
176
- rfr = _rfr
164
+ self .key .set (key )
177
165
178
- surrogate = d_spike_fx (v , _v_thr ) #d_spike_fx(v, thr + thr_theta)
179
- if tau_theta > 0. :
166
+ if self .max_one_spike :
167
+ rS = nn .one_hot (jnp .argmax (self .v .get (), axis = 1 ), num_classes = s .shape [1 ],
168
+ dtype = jnp .float32 ) ## get max-volt spike
169
+ s = s * rS ## mask out non-max volt spikes
170
+
171
+ if self .tau_theta > 0. :
180
172
## run one integration step for threshold dynamics
181
- thr_theta = _update_theta (dt , thr_theta , raw_spikes , tau_theta , theta_plus )
173
+ thr_theta = _update_theta (dt , self .thr_theta .get (), raw_s , self .tau_theta , self .theta_plus ) # .get())
174
+ self .thr_theta .set (thr_theta )
175
+
182
176
## update tols
183
- tols = (1. - s ) * tols + (s * t )
184
- if lower_clamp_voltage : ## ensure voltage never < v_rest
185
- v = jnp .maximum (v , v_rest )
186
- return v , s , raw_spikes , rfr , thr_theta , tols , key , surrogate
187
-
188
- @transition (output_compartments = ["j" , "v" , "s" , "s_raw" , "rfr" , "tols" , "surrogate" ])
189
- @staticmethod
190
- def reset (batch_size , n_units , v_rest , refract_T ):
191
- restVals = jnp .zeros ((batch_size , n_units ))
192
- j = restVals #+ 0
193
- v = restVals + v_rest
194
- s = restVals #+ 0
195
- s_raw = restVals
196
- rfr = restVals + refract_T
197
- #thr_theta = restVals ## do not reset thr_theta
198
- tols = restVals #+ 0
199
- surrogate = restVals + 1.
200
- return j , v , s , s_raw , rfr , tols , surrogate
177
+ self .tols .set ((1. - s ) * self .tols .get () + (s * t ))
178
+
179
+ if self .v_min is not None : ## ensures voltage never < v_rest
180
+ _v = jnp .maximum (_v , self .v_min )
181
+
182
+ self .v .set (_v )
183
+ self .s .set (s )
184
+ self .s_raw .set (raw_s )
185
+ self .rfr .set (_rfr )
186
+
187
+ @compilable
188
+ def reset (self ):
189
+ restVals = jnp .zeros ((self .batch_size , self .n_units ))
190
+ if not self .j .targeted :
191
+ self .j .set (restVals )
192
+ self .v .set (restVals + self .v_rest )
193
+ self .s .set (restVals )
194
+ self .s_raw .set (restVals )
195
+ self .rfr .set (restVals + self .refract_T )
196
+ self .tols .set (restVals )
197
+ #self.surrogate.set(restVals)
201
198
202
199
def save (self , directory , ** kwargs ):
203
200
## do a protected save of constants, depending on whether they are floats or arrays
0 commit comments