From 94f37f7e59d2c75e12afc3aa090d1a3c8ebec20a Mon Sep 17 00:00:00 2001 From: ago109 Date: Fri, 9 Aug 2024 12:38:05 -0400 Subject: [PATCH] cleaned up raf-cell --- ngclearn/components/neurons/spiking/RAFCell.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index af49f00e..b550aee7 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -97,13 +97,13 @@ class RAFCell(JaxComponent): b: oscillation dampening factor (Default: -1) - v_reset: membrane potential reset condition (Default: 1 mV) + v_reset: reset condition for membrane potential (Default: 1 mV) w_reset: reset condition for angular current driver (Default: 0) - v0: membrane potential initial condition (Default: 1 mV) + v0: initial condition for membrane potential (Default: 1 mV) - w0: angular driver initial condition (Default: 0) + w0: initial condition for angular current driver (Default: 0) resist_v: membrane resistance (Default: 1 mega-Ohm) @@ -185,11 +185,11 @@ def advance_state(self, j, v, w, s, tols): self.tols.set(tols) @staticmethod - def _reset(batch_size, n_units, v_reset, w_reset): + def _reset(batch_size, n_units, v0, w0): restVals = jnp.zeros((batch_size, n_units)) j = restVals # None - v = restVals + v_reset - w = restVals + w_reset + v = restVals + v0 + w = restVals + w0 s = restVals #+ 0 tols = restVals #+ 0 return j, v, w, s, tols