Skip to content

Major release update (to 2.0.0) #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 115 commits into from
Apr 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
115 commits
Select commit Hold shift + click to select a range
250b2bb
add initial patch mask features
rxng8 Mar 17, 2025
3804426
minor edit to bern-cell
Mar 21, 2025
8a5bc68
fixed bernoulli error cell
Mar 22, 2025
1f51c2e
Merge branches 'patch_mask_utils' and 'major_release_update' of githu…
rxng8 Mar 25, 2025
b51cfd0
example rate cell test
rxng8 Mar 25, 2025
f4d47d4
made some corrections to bern err-cell and heb syn
Mar 25, 2025
e408d9b
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
Mar 25, 2025
0ffb3d1
made some corrections to bern err-cell and heb syn
Mar 25, 2025
1c5164e
cleaned up bern-cell, hebb-syn
Mar 26, 2025
c4fe072
minor mod to model-utils
Mar 26, 2025
21f43f8
attempted rewrite of bernoulli-cell
Mar 26, 2025
7c67169
got bernoulli-cell rewritten and unit-tested
Mar 26, 2025
bd5b88d
edit to bern-cell
Mar 26, 2025
479d94a
bernoulli and poisson cells revised, unit-tested
Mar 26, 2025
74840d9
latency-cell refactored and unit-tested
Mar 26, 2025
adc74cf
refactored Rate Cell
rxng8 Mar 27, 2025
10dc640
minor revisions to input-encoders, revised phasor-cell w/ unit-test
Mar 29, 2025
0a803ff
revised and add unit-test for varTrace
Mar 29, 2025
d850f9d
revised and added unit-test for exp-kernel
Mar 29, 2025
0d24653
revised and added unit-test for exp-kernel
Mar 29, 2025
6af9fd4
revised slif cell w/ unit-test; needed mod to diffeq
Mar 29, 2025
b077ee0
revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/…
Mar 29, 2025
d4dfe38
revised lif-cell w/ unit-test
Mar 29, 2025
f6d56a4
revised unit-tests to pass globally; some minor patches to phasor-cel…
Mar 30, 2025
68ccc09
minor cleanup of unit-test for phasor
Mar 30, 2025
f4e661c
revised if-cell w/ unit-test
Mar 31, 2025
3c1ee98
revised if-cell w/ unit-test
Mar 31, 2025
db25673
revised quad-lif w/ unit-test
Mar 31, 2025
aab5e3b
revised adex-cell w/ unit test, minor cleanup of quad-lif
Mar 31, 2025
eaafb64
minor edit to adex unit-test
Mar 31, 2025
b30c8fe
refactor bernoulli, laplacian, and rewarderror cells
rxng8 Mar 31, 2025
26e27c4
revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells
Mar 31, 2025
b04822a
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
Mar 31, 2025
06d4a53
revised wtas-cell w/ unit test
Mar 31, 2025
0ed77fa
revised fh-cell w/ unit test
Mar 31, 2025
9ea587a
revised izh-cell w/ unit test
Mar 31, 2025
098f3db
patched ode_utils backend wrt jax, cleaned up unit-tests, added disab…
Mar 31, 2025
8a5958d
update rate cell
rxng8 Apr 1, 2025
f453623
fix test rate cell
rxng8 Apr 1, 2025
3193b72
update test for bernoulli cell
rxng8 Apr 1, 2025
55e9fc7
update refactoring for gaussian error cell
rxng8 Apr 1, 2025
4c22428
update unit testing for all graded neurons
rxng8 Apr 1, 2025
b3c47a2
wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up e…
Apr 1, 2025
af283dc
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
Apr 1, 2025
bfc200c
added rk2 support for H-H cell
Apr 1, 2025
43fbd9b
update rate cell and fix bug of passing a tuple of (jax Array -- not …
rxng8 Apr 1, 2025
6119846
update test rate cell
rxng8 Apr 1, 2025
9478695
refactored dense and trace-stdp syn w/ unit-test
Apr 1, 2025
6a4889a
refactored exp-stdp syn w/ unit-test
Apr 1, 2025
7fbae79
refactored event-stdp w/ unit-test
Apr 1, 2025
29b49ff
cleanup of stdp-syn
Apr 1, 2025
5272fdc
refactored bcm syn w/ unit-test
Apr 1, 2025
58c1d30
refactored stp-syn with unit-test
Apr 2, 2025
17540ea
cleaned up modulated
Apr 2, 2025
464ab10
refactored mstdp-et syn w/ unit-test
Apr 2, 2025
33e0cc1
refactored lava components to new sim-lib
Apr 2, 2025
5c56389
refactored conv/hebb-conv syn w/ unit-test
Apr 2, 2025
ff2628c
refactored/revised hebb-deconv syn w/ unit-test
Apr 2, 2025
ce13554
revised/refactored hebb/stdp conv/deconv syn w/ unit-tests
Apr 2, 2025
5bd93c2
updated modeling doc to point to hodgkin-huxley cell
Apr 2, 2025
620fb4a
updated modeling docs
Apr 2, 2025
e9e314d
fixed typo in adex-cell tutorial doc
Apr 2, 2025
ea3396d
revised tutorials to reflect new sim-lib config/syntax
Apr 2, 2025
e981e1d
revised tutorials to reflect new sim-lib config/syntax
Apr 2, 2025
d536e33
patched docs to reflect revisions/refactor
Apr 2, 2025
9f4f7f9
tweaked requirements in prep for major release
Apr 2, 2025
f480ca2
cleaned up a few unit tests to use deterministic syn init vals
Apr 3, 2025
f42b8a2
mod to requirements
Apr 3, 2025
d6860a8
nudge toml to upcoming 2.0.0
Apr 3, 2025
8317d3f
update to support docs in prep for 2.0.0
Apr 3, 2025
bc713f6
update patched synapses and their test cases
rxng8 Apr 3, 2025
4a71866
cleaned up syn modeling doc
Apr 3, 2025
e3791df
push hebbian synapse
rxng8 Apr 3, 2025
1b45362
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
rxng8 Apr 3, 2025
7dbaccd
push reinforce synapse
rxng8 Apr 4, 2025
9748000
push np seed
rxng8 Apr 4, 2025
5da1549
patched minor prior None arg issue in hebb-syn
Apr 4, 2025
242161e
moved reinforce-syn to right spot
Apr 4, 2025
b154e4e
update reinforce synapse and testing
rxng8 Apr 5, 2025
c5a13d4
tweaked trace-stdp and mstdpet
Apr 5, 2025
97619e5
patched mstdpet unit-test
Apr 5, 2025
c3f39f1
update reinforce synapse and test cases
rxng8 Apr 6, 2025
92def68
add reinforce synapse fix
rxng8 Apr 6, 2025
fe28a03
minor mod to mstdpet
Apr 6, 2025
40812ff
update test code for more than 1 steps
rxng8 Apr 6, 2025
d352293
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
rxng8 Apr 6, 2025
4e9176c
Updated monitors
willgebhardt Apr 7, 2025
f556293
patched tests to use process naming
Apr 7, 2025
9b75266
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
Apr 7, 2025
8dbf83a
Added wrapper for reset and advance_state
willgebhardt Apr 7, 2025
cd9f12a
Added a JaxProcess
willgebhardt Apr 7, 2025
b594529
update the old rate cell
rxng8 Apr 7, 2025
2e1b2e8
Merge branches 'major_release_update' and 'major_release_update' of g…
rxng8 Apr 7, 2025
20230ae
update old hebbian synapse
rxng8 Apr 7, 2025
f4c2e7d
minor edit to if-cell
Apr 9, 2025
0b0a508
ported over adex tutorial to new ngclearn format
Apr 9, 2025
389c7aa
hh-cell supports rk4 integration
Apr 9, 2025
d094948
clean up and integrated hodgkin-huxley mini lesson in neurocog tutorials
Apr 9, 2025
7d6d841
Update jaxProcess.py
willgebhardt Apr 10, 2025
e9d7068
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
willgebhardt Apr 10, 2025
e7f482b
update working reinforce synapse
rxng8 Apr 10, 2025
91da161
update correct reinforce and testing
rxng8 Apr 12, 2025
1110d98
update documentation
rxng8 Apr 12, 2025
6bbdb82
update features, documentation, and testing
rxng8 Apr 12, 2025
701b501
update testing for REINFORCE cell
rxng8 Apr 12, 2025
68d2435
update code and test
rxng8 Apr 12, 2025
1dc4bda
update code
rxng8 Apr 12, 2025
24963e5
add clipping gradient to model utils
rxng8 Apr 12, 2025
e302af2
update reinforce cell to the new model utils clip
rxng8 Apr 12, 2025
fd36854
major cleanup in prep for merge over to main/prep for major release
Apr 12, 2025
1ee59b9
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
Apr 12, 2025
fa717cd
update test cases
rxng8 Apr 12, 2025
3a3ce8e
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn in…
Apr 12, 2025
2956763
update to require file in docs
Apr 12, 2025
26293af
Merge branch 'main' into major_release_update
ago109 Apr 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,15 @@

<img src="docs/images/ngc-learn-logo.png" width="300">

<b>ngc-learn</b> is a Python library for building, simulating, and analyzing
biomimetic systems, neurobiological agents, spiking neuronal networks,
predictive coding circuitry, and models that learn via biologically-plausible
forms of credit assignment. This simulation toolkit is built on top of JAX and is
distributed under the 3-Clause BSD license.
<b>ngc-learn</b> is a Python library for building, simulating, and analyzing biophysical / neurobiological systems, spiking neuronal networks, predictive coding circuitry, and biomimetic (NeuroAI) agents that learn in a biologically-plausible manner. This simulation toolkit, meant to support computational neuroscience and brain-inspired computing research, is built on top of JAX and is distributed under the 3-Clause BSD license.

It is currently maintained by the
<a href="https://www.cs.rit.edu/~ago/nac_lab.html">Neural Adaptive Computing (NAC) laboratory</a>.

## <b>Documentation</b>

Official documentation, including tutorials, can be found
<a href="https://ngc-learn.readthedocs.io/en/latest/#">here</a>. The model museum repo,
<a href="https://ngc-learn.readthedocs.io/en/latest/#">here</a>. The model museum repo (ngc-museum),
which implements several historical models, can be found
<a href="https://github.com/NACLab/ngc-museum">here</a>.

Expand All @@ -36,8 +32,8 @@ ngc-learn requires:
1) Python (>=3.10)
2) NumPy (>=1.26.0)
3) SciPy (>=1.7.0)
4) ngcsimlib (>=0.3.b4), (visit official page <a href="https://github.com/NACLab/ngc-sim-lib">here</a>)
5) JAX (>= 0.4.28) (to enable GPU use, make sure to install one of the CUDA variants)
4) ngcsimlib (>=1.0.0), (visit official page <a href="https://github.com/NACLab/ngc-sim-lib">here</a>)
5) JAX (>=0.4.28) (to enable GPU use, make sure to install one of the CUDA variants)
<!--
5) scikit-learn (>=1.3.1) if using `ngclearn.utils.density`
6) matplotlib (>=3.4.3) if using `ngclearn.utils.viz`
Expand All @@ -46,7 +42,7 @@ ngc-learn requires:
-->

---
ngc-learn 1.2.beta2 and later require Python 3.10 or newer as well as ngcsimlib >=0.3.b4.
ngc-learn 2.0.0 and later require Python 3.10 or newer as well as ngcsimlib >=1.0.0.
ngc-learn's plotting capabilities (routines within `ngclearn.utils.viz`) require
Matplotlib (>=3.8.0) and imageio (>=2.31.5) and both plotting and density estimation
tools (routines within ``ngclearn.utils.density``) will require Scikit-learn (>=0.24.2).
Expand All @@ -66,7 +62,7 @@ running the above pip command if you want to use the GPU version.

The documentation includes more detailed
<a href="https://ngc-learn.readthedocs.io/en/latest/installation.html">installation instructions</a>.
Note that this library was developed on Ubuntu 20.04 and tested on Ubuntu(s) 18.04 and 20.04.
Note that this library was developed on Ubuntu 20.04/22.04 and tested on Ubuntu(s) 20.04 and 22.04.

If the installation was successful, you should see the following if you test
it against your Python interpreter, i.e., run the <code>$ python</code> command
Expand Down
2 changes: 1 addition & 1 deletion docs/museum/snn_bfa.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ of $1000$ `SLIF` cells) similar to the one below:
<img src="../images/museum/bfa_snn/bfasnn_codes.jpg" width="450" />

Intriguingly, we see that the latent codes represented by the BFA-SNN's hidden
layer spikes yield a rather (piecewise) linearly-separable transformation
layer spikes yield a rather (piecewise) linearly-separable representation
of the input digits, making the process of mapping inputs to label vectors
much easier for the model's second layer of classification LIF units.
Note that, in the `BFA_SNN` model exhibit class, we estimated
Expand Down
6 changes: 3 additions & 3 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ numpy>=1.26.0
scikit-learn>=0.24.2
scipy>=1.7.0
matplotlib>=3.8.0
jax>=0.4.18
jaxlib>=0.4.18
jax>=0.4.28
jaxlib>=0.4.28
imageio>=2.31.5
ngcsimlib>=0.3.b4
ngcsimlib>=1.0.0
11 changes: 5 additions & 6 deletions docs/tutorials/model_basics/evolving_synapses.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ We do this specifically as follows:
```python
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess
from ngclearn.components import HebbianSynapse, RateCell
import ngclearn.utils.weight_distribution as dist

Expand Down Expand Up @@ -49,16 +49,16 @@ with Context("Circuit") as circuit:
Wab.post << b.zF

## create and compile core simulation commands
evolve_process = (Process()
evolve_process = (JaxProcess()
>> a.evolve)
circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")

advance_process = (Process()
advance_process = (JaxProcess()
>> a.advance_state)
circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
>> a.reset)
reset_process = (JaxProcess()
>> a.reset)
circuit.wrap_and_add_command(jit(reset_process.pure), name="reset")

## set up non-compiled utility commands
Expand All @@ -83,7 +83,6 @@ for ts in range(x_seq.shape[1]):
circuit.advance(t=ts*1., dt=1.)
circuit.evolve(t=ts*1., dt=1.)
print(" {}: input = {} ~> Wab = {}".format(ts, x_t, Wab.weights.value))

```

Your code should produce the same output (towards the bottom):
Expand Down
5 changes: 3 additions & 2 deletions docs/tutorials/model_basics/model_building.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ While building our dynamical system we will set up a Context and then add the th
```python
from jax import numpy as jnp, random
from ngclearn import Context
from ngclearn.utils import JaxProcess
from ngcsimlib.compilers.process import Process
from ngclearn.components import RateCell, HebbianSynapse
import ngclearn.utils.weight_distribution as dist
Expand Down Expand Up @@ -71,13 +72,13 @@ This is simply done with the use of the following convenience function calls:

```python
## configure desired commands for simulation object
reset_process = (Process()
reset_process = (JaxProcess()
>> a.reset
>> Wab.reset
>> b.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

advance_process = (Process()
advance_process = (JaxProcess()
>> a.advance_state
>> Wab.advance_state
>> b.advance_state)
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/adex_cell.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import numpy as np

from ngclearn.utils.model_utils import scanner
from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.adExCell import AdExCell

Expand All @@ -48,11 +48,11 @@ with Context("Model") as model:
)

## create and compile core simulation commands
advance_process = (Process()
advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/error_cell.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ The code you would write amounts to the below:
```python
from jax import numpy as jnp, jit
from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process, transition
from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell

Expand All @@ -64,11 +64,11 @@ T = 5 ## number time steps to simulate
with Context("Model") as model:
cell = GaussianErrorCell("z0", n_units=3)

advance_process = (Process()
advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/fitzhugh_nagumo_cell.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ from jax import numpy as jnp, random, jit
import numpy as np

from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell

Expand All @@ -40,11 +40,11 @@ with Context("Model") as model:
gamma=gamma, v0=v0, w0=w0, integration_type="euler")

## create and compile core simulation commands
advance_process = (Process()
advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/neurocog/hebbian.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ Wab.post << b.zF
as well as (a bit later in the model construction code):

```python
evolve_process = (Process()
evolve_process = (JaxProcess()
>> a.evolve)
circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")

advance_process = (Process()
advance_process = (JaxProcess()
>> a.advance_state)
circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")
```
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/hodgkin_huxley_cell.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import numpy as np

from ngclearn.utils.model_utils import scanner
from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell

Expand Down Expand Up @@ -52,11 +52,11 @@ with Context("Model") as model:
)

## create and compile core simulation commands
advance_process = (Process()
advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/input_cells.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ spike train over $100$ steps in time as follows:
```python
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess

from ngclearn.utils.viz.raster import create_raster_plot
## import model-specific mechanisms
Expand All @@ -56,11 +56,11 @@ T = 100 ## number time steps to simulate
with Context("Model") as model:
cell = BernoulliCell("z0", n_units=10, key=subkeys[0])

advance_process = (Process()
advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/izhikevich_cell.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ from jax import numpy as jnp, random, jit
import numpy as np

from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.izhikevichCell import IzhikevichCell

Expand All @@ -44,11 +44,11 @@ with Context("Model") as model:
integration_type="euler", v0=v0, w0=w0, key=subkeys[0])

## create and compile core simulation commands
advance_process = (Process()
advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/lif.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ cell, you would write code akin to the following:
from jax import numpy as jnp, random, jit

from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.LIFCell import LIFCell
from ngclearn.utils.viz.spike_plot import plot_spiking_neuron
Expand All @@ -47,11 +47,11 @@ with Context("Model") as model:
refract_time=2., key=subkeys[0])

## create and compile core simulation commands
advance_process = (Process()
advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

Expand Down
8 changes: 4 additions & 4 deletions docs/tutorials/neurocog/mod_stdp.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ and the required compiled simulation and dynamic commands, can be done as follow
```python
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse,
RewardErrorCell, VarTrace)
Expand Down Expand Up @@ -75,13 +75,13 @@ with Context("Model") as model:
tr1 = VarTrace("tr1", n_units=1, tau_tr=tau_post, a_delta=Aminus)
rpe = RewardErrorCell("r", n_units=1, alpha=0.)

evolve_process = (Process()
evolve_process = (JaxProcess()
>> W_stdp.evolve
>> W_mstdp.evolve
>> W_mstdpet.evolve)
model.wrap_and_add_command(jit(evolve_process.pure), name="evolve")

advance_process = (Process()
advance_process = (JaxProcess()
>> tr0.advance_state
>> tr1.advance_state
>> rpe.advance_state
Expand All @@ -90,7 +90,7 @@ with Context("Model") as model:
>> W_mstdpet.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> W_stdp.reset
>> W_mstdp.reset
>> W_mstdpet.reset
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/rate_cell.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ specifically the rate-cell (RateCell). Let's start with the file's header

```python
from jax import numpy as jnp, random, jit
from ngcsimlib.compilers.process import Process, transition
from ngclearn.utils import JaxProcess
from ngcsimlib.context import Context
## import model-specific elements
from ngclearn.components.neurons.graded.rateCell import RateCell
Expand All @@ -40,11 +40,11 @@ with Context("Model") as model: ## model/simulation definition
prior=("gaussian", gamma), integration_type="euler", key=subkeys[0])

## instantiate desired core commands that drive the simulation
advance_process = (Process()
advance_process = (JaxProcess()
>> cell.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> cell.reset)
model.wrap_and_add_command(jit(reset_process.pure), name="reset")

Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/neurocog/short_term_plasticity.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ STF-dominated dynamics):
```python
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
from ngcsimlib.compilers.process import Process
from ngclearn.utils import JaxProcess
## import model-specific mechanisms
from ngclearn.components import PoissonCell, STPDenseSynapse, LIFCell
import ngclearn.utils.weight_distribution as dist
Expand Down Expand Up @@ -98,13 +98,13 @@ with Context("Model") as model:
W.inputs << z0.outputs ## z0 -> W
z1.j << W.outputs ## W -> z1

advance_process = (Process()
advance_process = (JaxProcess()
>> z0.advance_state
>> W.advance_state
>> z1.advance_state)
model.wrap_and_add_command(jit(advance_process.pure), name="advance")

reset_process = (Process()
reset_process = (JaxProcess()
>> z0.reset
>> z1.reset
>> W.reset)
Expand Down
Loading
Loading