
Conversation

@michaeldeistler (Contributor) commented on Oct 23, 2025

Hi, here is a first attempt at addressing issue #717: it adds a function F that we can use to step through the model as x_{t+1} = F(x_t), where x_t is a vector of states (without observables).

This is largely a modified version of build_init_and_step_dynamics_fn, but the init function now also returns two functions: one that maps the full state pytree to a vector of only the "true" states, and one that maps this vector back to the full state pytree including the observables. It was a bit of a pain to get it working with jit, but it seems to be running now. I haven't tested yet whether it works nicely with gradients (edit: seems to be fine now).
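
For illustration, here is a rough sketch of the intended round-trip between the state vector and the full state pytree. The converter names follow the snippet in the comment further down, but their exact signatures are an assumption here, so treat this as a sketch rather than the final API:

# Sketch (signatures assumed): map the vector of "true" states to the full state
# pytree (including observables), and map the full pytree back to the vector.
full_states = states_to_full_pytree(states_vec)
states_vec = full_pytree_to_states(full_states)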

  • I am wondering whether we need a separate init function. Maybe it would be easier if the main function that returns the step function already initialised the model (edit: removed the init function).

  • The way inputs are handled here seems a bit clunky, but I don't have a good idea of how to make it better yet (edit: @bantin's solution from "Working directly with ODE dynamics function" #715 seems nicer; after discussing, we will keep it as it is for now).

  • Note, I still have to add tests and update the changelog (edit: added!).

The main functionality is there and running:

import jaxley as jx
import jaxley.optimize.transforms as jt
from jaxley.channels import Leak
from jaxley.channels.hh import HH
from jaxley.integrate import add_stimuli
from jaxley.utils.dynamics import build_step_dynamics_fn
import jax.numpy as jnp
import numpy as np
import optax
from jax import jit, value_and_grad


# Build a simple cell
ncomp_per_branch = 8
comp = jx.Compartment()
branch = jx.Branch(comp, ncomp_per_branch)
cell = jx.Cell(branch, parents=[-1, 0, 0])
cell.insert(HH())
cell.insert(Leak())

# Make some parameters trainable
cell.make_trainable("Leak_gLeak")
cell.make_trainable("v")
params = cell.get_parameters()
# Define parameter transform and apply it to the parameters.
transform = jx.ParamTransform(
    [
        {"Leak_gLeak": jt.SigmoidTransform(0.00001, 0.0002)},
        {"v": jt.SigmoidTransform(-100, -30)},
    ]
)
opt_params = transform.inverse(params)
params = transform.forward(opt_params)
cell.to_jax()

# Test jit and training
# ----------------------------------------------

# Add some inputs
externals = cell.externals.copy()
external_inds = cell.external_inds.copy()
current = jx.step_current(
    i_delay=1.0, i_dur=2.0, i_amp=0.08, delta_t=0.025, t_max=0.075
)
data_stimuli = None
data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)
externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)

def get_externals_now(externals, step):
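    """Slice each external input array along its time axis at the given time step."""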
    externals_now = {}
    for key in externals.keys():
        externals_now[key] = externals[key][:, step]
    return externals_now

target_voltage = -60.0
state_idx = -30  # index into the flat state vector that is used in the loss

# Define the optimizer
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(opt_params)

def loss(opt_params):
    params = transform.forward(opt_params)

    # initialise and build the step function
    states_vec, step_dynamics_fn, _, _ = build_step_dynamics_fn(
        cell, solver="bwd_euler", delta_t=0.025, params=params
    )

    # JIT the step function for speed
    @jit
    def step_fn_vec_to_vec(states_vec, externals_now, params=None):
        states_vec = step_dynamics_fn(
            states_vec,
            params,
            externals=externals_now,
            external_inds=external_inds,
            delta_t=0.025,
        )
        return states_vec

    states_vecs = [states_vec]

    # Simulate the model
    for step in range(3):
        # Get inputs at this time step
        externals_now = get_externals_now(externals, step)
        # Step the ODE
        states_vec = step_fn_vec_to_vec(states_vec, externals_now, params)
        # Store the state
        states_vecs.append(states_vec)
    # Compute the loss at the last time step
    loss = jnp.mean((states_vecs[-1][state_idx] - target_voltage) ** 2)
    return loss

# Compute the gradient of the loss with respect to the parameters
grad_loss = value_and_grad(loss, argnums=0)
value, gradient = grad_loss(opt_params)

updates, opt_state = optimizer.update(gradient, opt_state)
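
For completeness, the updates above can be applied and wrapped into a short training loop. This is just a sketch of the usual optax pattern (optax.apply_updates), reusing the loss, optimizer, and opt_params defined above:

# Sketch: a few optimisation steps, reusing loss, optimizer, and opt_params from above.
for epoch in range(10):
    value, gradient = grad_loss(opt_params)
    updates, opt_state = optimizer.update(gradient, opt_state)
    opt_params = optax.apply_updates(opt_params, updates)
    print(f"epoch {epoch}, loss {float(value):.4f}")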

@michaeldeistler force-pushed the matthijs2 branch 2 times, most recently from beee15f to 6094102 on Oct 23, 2025, 08:44
@michaeldeistler changed the title from "Matthijs2" to "Dynamics function which uses jnp.ndarray (not pytrees)" on Oct 23, 2025
@jaxleyverse deleted a comment from Matthijspals on Oct 23, 2025
@Matthijspals (Contributor) commented on Oct 23, 2025

Should be done!

It's nice how straightforward it is to get Jacobians:

from jaxley.utils.dynamics import build_step_dynamics_fn
import jaxley as jx
from jaxley.channels import Leak
from jax import jacfwd
import matplotlib.pyplot as plt
%matplotlib inline

# Build a simple cell
comp = jx.Compartment()
branch = jx.Branch(comp, 8)
cell = jx.Cell(branch, parents=[-1, 0, 0])
cell.insert(Leak())
cell.to_jax()

# Obtain the step function
states_vec, step_dynamics_fn, states_to_pytree, states_to_full_pytree, full_pytree_to_states = build_step_dynamics_fn(
    cell, solver="bwd_euler", delta_t=0.025
)

# Obtain and plot Jacobian
jacobian = jacfwd(step_dynamics_fn)(states_vec)
plt.imshow(jacobian, cmap='viridis')
[Figure: heatmap of the Jacobian of the step function]
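
As a quick usage example (not part of the new functionality itself), the one-step Jacobian also makes it easy to look at the eigenvalues of the linearised dynamics, e.g. to check how strongly the step contracts:

# Sketch: eigenvalues of the one-step Jacobian (|eigenvalue| < 1 means the
# linearised step is contracting along that direction).
import jax.numpy as jnp
eigvals = jnp.linalg.eigvals(jacobian)
plt.scatter(eigvals.real, eigvals.imag)
plt.xlabel("Re(eigenvalue)")
plt.ylabel("Im(eigenvalue)")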
