Skip to content

Commit 43b0fed

Browse files
proposed new api; not tested (#731)
* proposed new api; not tested --------- Co-authored-by: Matthijspals <matthijs-pals@hotmail.com>
1 parent 45aee89 commit 43b0fed

File tree

6 files changed

+354
-407
lines changed

6 files changed

+354
-407
lines changed

docs/jaxley.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Simulation
2323

2424
jaxley.integrate
2525
jaxley.integrate.build_init_and_step_fn
26-
jaxley.utils.dynamics.build_step_dynamics_fn
26+
jaxley.utils.dynamics.build_dynamic_state_utils
2727

2828

2929
Morphologies

jaxley/integrate.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,48 @@ def build_init_and_step_fn(
2323
"""Return ``init_fn`` and ``step_fn`` which initialize modules and run update steps.
2424
2525
This method can be used to gain additional control over the simulation workflow.
26-
It exposes the ``step`` function, which can be used to perform step-by-step updates
26+
It exposes a step function, which can be used to perform step-by-step updates
2727
of the differential equations.
2828
2929
Args:
30-
module: A `Module` object that e.g. a cell.
31-
voltage_solver: Voltage solver used in step. Defaults to "jaxley.stone".
30+
module: A ``jx.Module`` object that, for example a ``jx.Cell``.
31+
voltage_solver: Voltage solver used in step. Defaults to "jaxley.dhs".
3232
solver: ODE solver. Defaults to "bwd_euler".
3333
3434
Returns:
35-
init_fn, step_fn: Functions that initialize the state and parameters, and
36-
perform a single integration step, respectively.
35+
36+
* ``init_fn(params, all_states=None, param_state=None, delta_t=0.025)``
37+
38+
Callable which initializes the states and parameters.
39+
40+
* Args:
41+
42+
* ``params`` (list[dict]): returned by `.get_parameters()`.
43+
* ``all_states`` (dict | None = None): typically `None`.
44+
* ``param_state`` (list[dict] | None = None): returned by `.data_set()`.
45+
* ``delta_t`` (float = 0.025): the time step.
46+
47+
* Returns:
48+
49+
* ``all_states`` (dict).
50+
* ``all_params`` (dict), which can be passed to the `step_fn`.
51+
52+
* ``step_fn(all_states, all_params, external_inds, externals, delta_t=0.025)``
53+
54+
Callable which performs a single integration step.
55+
56+
* Args:
57+
58+
* ``all_states`` (dict): returned by `init_fn()`.
59+
* ``all_params`` (dict): returned by `init_fn()`.
60+
* ``externals`` (dict): obtained with `module.externals.copy()` but using
61+
only the external input at the current time step (see examples below).
62+
* ``external_inds`` (dict): obtained with `module.external_inds.copy()`.
63+
* ``delta_t`` (float): the time step.
64+
65+
* Returns:
66+
67+
* Updated ``all_states`` (dict).
3768
3869
Example usage
3970
^^^^^^^^^^^^^

jaxley/modules/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def __init__(self):
175175
# List of types of all `jx.Channel`s.
176176
self.channels: List[Channel] = []
177177
self.membrane_current_names: List[str] = []
178+
self.synapse_current_names: List[str] = []
178179

179180
# List of all pumps.
180181
self.pumped_ions: List[str] = []

jaxley/modules/network.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ def _synapse_currents(
336336
) -> tuple[dict, tuple[Array, Array]]:
337337
voltages = states["v"]
338338

339+
current_states = {}
340+
for name in self.synapse_current_names:
341+
current_states[name] = jnp.zeros_like(voltages)
342+
339343
grouped_syns = edges.groupby("type", sort=False, group_keys=False)
340344
pre_syn_inds = grouped_syns["pre_index"].apply(list)
341345
post_syn_inds = grouped_syns["post_index"].apply(list)
@@ -617,6 +621,10 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):
617621
self.base.edges["controlled_by_param"] = 0
618622
self._edges_in_view = self.edges.index.to_numpy()
619623

624+
current_name = f"i_{synapse_type._name}"
625+
if current_name not in self.base.synapse_current_names:
626+
self.base.synapse_current_names.append(current_name)
627+
620628
def _add_params_to_edges(self, synapse_type, indices):
621629
# Add parameters and states to the `.edges` table.
622630
for key, param_val in synapse_type.synapse_params.items():

0 commit comments

Comments
 (0)