22from jax import numpy as jnp , random , jit , nn
33from functools import partial
44from ngclearn .utils import tensorstats
5- from ngcsimlib . deprecators import deprecate_args
5+ from ngcsimlib import deprecate_args
66from ngcsimlib .logger import info , warn
77from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
88 step_euler , step_rk2
9- from ngclearn .utils .surrogate_fx import (secant_lif_estimator , arctan_estimator ,
10- triangular_estimator ,
11- straight_through_estimator )
9+ # from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10+ # triangular_estimator,
11+ # straight_through_estimator)
1212
13- from ngcsimlib .compilers .process import transition
14- #from ngcsimlib.component import Component
13+ from ngcsimlib .parser import compilable
1514from ngcsimlib .compartment import Compartment
1615
1716
@@ -35,7 +34,7 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
3534 The specific differential equation that characterizes this cell
3635 is (for adjusting v, given current j, over time) is:
3736
38- | tau_m * dv/dt = (v_rest - v) + j * R
37+ | tau_m * dv/dt = j * R
3938 | where R is the membrane resistance and v_rest is the resting potential
4039 | also, if a spike occurs, v is set to v_reset
4140
@@ -91,10 +90,10 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
9190 """
9291
9392 @deprecate_args (thr_jitter = None )
94- def __init__ (self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. ,
95- v_reset = - 60. , refract_time = 0. , integration_type = "euler" ,
96- surrogate_type = "straight_through" , lower_clamp_voltage = True ,
97- ** kwargs ):
93+ def __init__ (
94+ self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. , v_reset = - 60. , refract_time = 0. ,
95+ integration_type = "euler" , surrogate_type = "straight_through" , lower_clamp_voltage = True , ** kwargs
96+ ):
9897 super ().__init__ (name , ** kwargs )
9998
10099 ## Integration properties
@@ -118,12 +117,12 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
118117 self .n_units = n_units
119118
120119 ## set up surrogate function for spike emission
121- if surrogate_type == "arctan" :
122- self .spike_fx , self .d_spike_fx = arctan_estimator ()
123- elif surrogate_type == "triangular" :
124- self .spike_fx , self .d_spike_fx = triangular_estimator ()
125- else : ## default: straight_through
126- self .spike_fx , self .d_spike_fx = straight_through_estimator ()
120+ # if surrogate_type == "arctan":
121+ # self.spike_fx, self.d_spike_fx = arctan_estimator()
122+ # elif surrogate_type == "triangular":
123+ # self.spike_fx, self.d_spike_fx = triangular_estimator()
124+ # else: ## default: straight_through
125+ # self.spike_fx, self.d_spike_fx = straight_through_estimator()
127126
128127
129128 ## Compartment setup
@@ -138,47 +137,48 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
138137 units = "ms" ) ## time-of-last-spike
139138 self .surrogate = Compartment (restVals + 1. , display_name = "Surrogate State Value" )
140139
141- @transition (output_compartments = ["v" , "s" , "rfr" , "tols" , "key" , "surrogate" ])
142- @staticmethod
140+ @compilable
143141 def advance_state (
144- t , dt , tau_m , resist_m , v_rest , v_reset , refract_T , thr , lower_clamp_voltage , intgFlag , d_spike_fx , key ,
145- j , v , rfr , tols
142+ self , dt , t
146143 ):
147144 ## run one integration step for neuronal dynamics
148- j = j * resist_m
145+ j = self . j . get () * self . resist_m
149146
150147 ### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
151148 ## update voltage / membrane potential
152- v_params = (j , rfr , tau_m , refract_T )
153- if intgFlag == 1 :
154- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
149+ v_params = (j , self . rfr . get (), self . tau_m , self . refract_T )
150+ if self . intgFlag == 1 :
151+ _ , _v = step_rk2 (0. , self . v . get () , _dfv , dt , v_params )
155152 else :
156- _ , _v = step_euler (0. , v , _dfv , dt , v_params )
153+ _ , _v = step_euler (0. , self . v . get () , _dfv , dt , v_params )
157154 ## obtain action potentials/spikes
158- s = (_v > thr ) * 1.
155+ s = (_v > self . thr ) * 1.
159156 ## update refractory variables
160- rfr = (rfr + dt ) * (1. - s )
157+ rfr = (self . rfr . get () + dt ) * (1. - s )
161158 ## perform hyper-polarization of neuronal cells
162- v = _v * (1. - s ) + s * v_reset
159+ v = _v * (1. - s ) + s * self .v_reset
160+
161+ #surrogate = d_spike_fx(v, self.thr)
163162
164- surrogate = d_spike_fx (v , thr )
165163 ## update tols
166- tols = (1. - s ) * tols + (s * t )
167- if lower_clamp_voltage : ## ensure voltage never < v_rest
168- v = jnp .maximum (v , v_rest )
169- return v , s , rfr , tols , key , surrogate
170-
171- @transition (output_compartments = ["j" , "v" , "s" , "rfr" , "tols" , "surrogate" ])
172- @staticmethod
173- def reset (batch_size , n_units , v_rest , refract_T ):
174- restVals = jnp .zeros ((batch_size , n_units ))
175- j = restVals #+ 0
176- v = restVals + v_rest
177- s = restVals #+ 0
178- rfr = restVals + refract_T
179- tols = restVals #+ 0
180- surrogate = restVals + 1.
181- return j , v , s , rfr , tols , surrogate
164+ self .tols .set ((1. - s ) * self .tols .get () + (s * t ))
165+ if self .lower_clamp_voltage : ## ensure voltage never < v_rest
166+ _v = jnp .maximum (v , self .v_rest )
167+
168+ self .v .set (_v )
169+ self .s .set (s )
170+ self .rfr .set (rfr )
171+
172+ @compilable
173+ def reset (self ):
174+ restVals = jnp .zeros ((self .batch_size , self .n_units ))
175+ if not self .j .targeted :
176+ self .j .set (restVals )
177+ self .v .set (restVals + self .v_rest )
178+ self .s .set (restVals )
179+ self .rfr .set (restVals + self .refract_T )
180+ self .tols .set (restVals )
181+ #surrogate = restVals + 1.
182182
183183 def save (self , directory , ** kwargs ):
184184 ## do a protected save of constants, depending on whether they are floats or arrays
0 commit comments