@@ -23,17 +23,48 @@ def build_init_and_step_fn(
2323 """Return ``init_fn`` and ``step_fn`` which initialize modules and run update steps.
2424
2525 This method can be used to gain additional control over the simulation workflow.
26- It exposes the `` step`` function, which can be used to perform step-by-step updates
26+ It exposes a step function, which can be used to perform step-by-step updates
2727 of the differential equations.
2828
2929 Args:
30- module: A `Module` object that e.g. a cell .
31- voltage_solver: Voltage solver used in step. Defaults to "jaxley.stone ".
30+ module: A ``jx. Module`` object that, for example a ``jx.Cell`` .
31+ voltage_solver: Voltage solver used in step. Defaults to "jaxley.dhs ".
3232 solver: ODE solver. Defaults to "bwd_euler".
3333
3434 Returns:
35- init_fn, step_fn: Functions that initialize the state and parameters, and
36- perform a single integration step, respectively.
35+
36+ * ``init_fn(params, all_states=None, param_state=None, delta_t=0.025)``
37+
38+ Callable which initializes the states and parameters.
39+
40+ * Args:
41+
42+ * ``params`` (list[dict]): returned by `.get_parameters()`.
43+ * ``all_states`` (dict | None = None): typically `None`.
44+ * ``param_state`` (list[dict] | None = None): returned by `.data_set()`.
45+ * ``delta_t`` (float = 0.025): the time step.
46+
47+ * Returns:
48+
49+ * ``all_states`` (dict).
50+ * ``all_params`` (dict), which can be passed to the `step_fn`.
51+
52+ * ``step_fn(all_states, all_params, external_inds, externals, delta_t=0.025)``
53+
54+ Callable which performs a single integration step.
55+
56+ * Args:
57+
58+ * ``all_states`` (dict): returned by `init_fn()`.
59+ * ``all_params`` (dict): returned by `init_fn()`.
60+ * ``externals`` (dict): obtained with `module.externals.copy()` but using
61+ only the external input at the current time step (see examples below).
62+ * ``external_inds`` (dict): obtained with `module.external_inds.copy()`.
63+ * ``delta_t`` (float): the time step.
64+
65+ * Returns:
66+
67+ * Updated ``all_states`` (dict).
3768
3869 Example usage
3970 ^^^^^^^^^^^^^
0 commit comments