Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# 0.11.6 (pre-release)

### 🧩 New features

- add step function that allows stepping through a simulation with a vector-valued state function (#719 @matthijspals)

### 🛠️ Internal updates

- separate getting the currents from `get_all_states()` (#727, @michaeldeistler). To
Expand Down
38 changes: 17 additions & 21 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

# -- Project information -----------------------------------------------------

project = 'Jaxley'
copyright = '2024, Jaxleyverse team'
author = 'Jaxleyverse team'
project = "Jaxley"
copyright = "2024, Jaxleyverse team"
author = "Jaxleyverse team"


# -- General configuration ---------------------------------------------------
Expand All @@ -45,19 +45,15 @@
"jax": ("https://jax.readthedocs.io/en/latest", None),
}

source_suffix = {
'.rst': 'restructuredtext',
'.myst': 'myst-nb',
'.ipynb': 'myst-nb'
}
source_suffix = {".rst": "restructuredtext", ".myst": "myst-nb", ".ipynb": "myst-nb"}

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# Myst-NB
myst_enable_extensions = [
Expand All @@ -76,16 +72,16 @@
#
html_title = ""
html_logo = "logo.png"
html_theme = 'sphinx_book_theme'
html_theme = "sphinx_book_theme"
html_theme_options = {
'repository_url': 'https://github.com/jaxleyverse/jaxley',
"repository_url": "https://github.com/jaxleyverse/jaxley",
"use_repository_button": True,
"use_download_button": False,
'repository_branch': 'main',
"path_to_docs": 'docs',
'launch_buttons': {
'colab_url': 'https://colab.research.google.com',
'binderhub_url': 'https://mybinder.org'
"repository_branch": "main",
"path_to_docs": "docs",
"launch_buttons": {
"colab_url": "https://colab.research.google.com",
"binderhub_url": "https://mybinder.org",
},
"toc_title": "Navigation",
"show_navbar_depth": 1,
Expand All @@ -96,8 +92,8 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_css_files = ['custom.css']
html_static_path = ["_static"]
html_css_files = ["custom.css"]

autosummary_generate = True
autodoc_typehints = "description"
Expand All @@ -107,5 +103,5 @@
"members": True,
"undoc-members": True,
"inherited-members": True,
"show-inheritance": True
}
"show-inheritance": True,
}
1 change: 1 addition & 0 deletions docs/jaxley.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Simulation

jaxley.integrate
jaxley.integrate.build_init_and_step_fn
jaxley.utils.dynamics.build_dynamic_state_utils


Morphologies
Expand Down
51 changes: 42 additions & 9 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,48 @@ def build_init_and_step_fn(
"""Return ``init_fn`` and ``step_fn`` which initialize modules and run update steps.

This method can be used to gain additional control over the simulation workflow.
It exposes the ``step`` function, which can be used to perform step-by-step updates
It exposes a step function, which can be used to perform step-by-step updates
of the differential equations.

Args:
module: A `Module` object that e.g. a cell.
voltage_solver: Voltage solver used in step. Defaults to "jaxley.stone".
module: A ``jx.Module`` object that, for example a ``jx.Cell``.
voltage_solver: Voltage solver used in step. Defaults to "jaxley.dhs".
solver: ODE solver. Defaults to "bwd_euler".

Returns:
init_fn, step_fn: Functions that initialize the state and parameters, and
perform a single integration step, respectively.

* ``init_fn(params, all_states=None, param_state=None, delta_t=0.025)``

Callable which initializes the states and parameters.

* Args:

* ``params`` (list[dict]): returned by `.get_parameters()`.
* ``all_states`` (dict | None = None): typically `None`.
* ``param_state`` (list[dict] | None = None): returned by `.data_set()`.
* ``delta_t`` (float = 0.025): the time step.

* Returns:

* ``all_states`` (dict).
* ``all_params`` (dict), which can be passed to the `step_fn`.

* ``step_fn(all_states, all_params, external_inds, externals, delta_t=0.025)``

Callable which performs a single integration step.

* Args:

* ``all_states`` (dict): returned by `init_fn()`.
* ``all_params`` (dict): returned by `init_fn()`.
* ``externals`` (dict): obtained with `module.externals.copy()` but using
only the external input at the current time step (see examples below).
* ``external_inds`` (dict): obtained with `module.external_inds.copy()`.
* ``delta_t`` (float): the time step.

* Returns:

* Updated ``all_states`` (dict).

Example usage
^^^^^^^^^^^^^
Expand Down Expand Up @@ -70,10 +101,12 @@ def build_init_and_step_fn(
# Initialize.
init_fn, step_fn = build_init_and_step_fn(cell)
states, params = init_fn(params)
recordings = [
states[rec_state][rec_ind][None]
for rec_state, rec_ind in zip(rec_states, rec_inds)
]
recordings = [jnp.asarray(
[
all_states[rec_state][rec_ind]
for rec_state, rec_ind in zip(rec_states, rec_inds)
]
)]

# Loop over the ODE. The `step_fn` can be jitted for improving speed.
steps = int(t_max / delta_t) # Steps to integrate
Expand Down
1 change: 1 addition & 0 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __init__(self):
# List of types of all `jx.Channel`s.
self.channels: List[Channel] = []
self.membrane_current_names: List[str] = []
self.synapse_current_names: List[str] = []

# List of all pumps.
self.pumped_ions: List[str] = []
Expand Down
8 changes: 8 additions & 0 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ def _synapse_currents(
) -> tuple[dict, tuple[Array, Array]]:
voltages = states["v"]

current_states = {}
for name in self.synapse_current_names:
current_states[name] = jnp.zeros_like(voltages)

grouped_syns = edges.groupby("type", sort=False, group_keys=False)
pre_syn_inds = grouped_syns["pre_index"].apply(list)
post_syn_inds = grouped_syns["post_index"].apply(list)
Expand Down Expand Up @@ -617,6 +621,10 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):
self.base.edges["controlled_by_param"] = 0
self._edges_in_view = self.edges.index.to_numpy()

current_name = f"i_{synapse_type._name}"
if current_name not in self.base.synapse_current_names:
self.base.synapse_current_names.append(current_name)

def _add_params_to_edges(self, synapse_type, indices):
# Add parameters and states to the `.edges` table.
for key, param_val in synapse_type.synapse_params.items():
Expand Down
Loading
Loading