New Pytree Backend #646
Conversation
|
Things we agreed upon:
There are two options for the central backend.

Option 1:

import jax.numpy as jnp
edges = {
(0, 1): {"radius": jnp.asarray([1.0, 1.0])},
(1, 2): {"radius": jnp.asarray([1.0, 1.0])},
(2, 3): {},
}
nodes = {
0: {"radius": jnp.asarray([0.0])},
1: {"radius": jnp.asarray([0.0]), "length": jnp.asarray([1.0])},
2: {"neural_net": jnp.pytree},  # pseudocode: any pytree of arrays
3: {"gNa": jnp.asarray([0, 1])},
}
# Arguments pro:
# - nodes[3]["radius"] = 1.0
# Arguments against making this the _central structure_:
# - we are not sure if vmap across stim-compartments works.
# - generally vmap is more difficult and could (maybe) be done via tree_utils.
# - we would have to build the structure below anyways.
# - this is a larger refactor.

Option 2 (pseudocode; values denote array shapes):

nodes = {
"index": jnp.ndarray((4,)),
"radius": jnp.ndarray((4,)),
"gNa": jnp.ndarray((4, 2)),
"neural_net": Dict[str, jnp.ndarray((4, 3))],
}
indices = nodes_pd.query(...).index
nodes[nodes["index"] == indices]["radius"] = 1.0
edges = {
"pre": jnp.asarray([0, 1, 2]),
"post": jnp.asarray([0, 1, 2]),
"radius": jnp.asarray([0.0, 1.0, jnp.nan]),
}

A nice and desirable feature would be to vmap across stimuli, e.g.:

# vmap across stimulus locations.
stim_comps = [0, 1, 2]
def run(stim_comp):
    cell.comp(stim_comp).stimulate(...)
    return jx.integrate(cell)
vmapped_run = vmap(run)
vmapped_run(stim_comps)

This is, in principle, supported for JAX arrays:

def simple_test(index):
    new_array = my_array.at[index].set(3.0)
    return jnp.sum(new_array ** 2)

vmapped_test = vmap(simple_test)
indices = jnp.asarray([0, 1, 2])
my_array = jnp.asarray([3.2, 4.3, 5.8]) |
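A runnable version of the sketch above, with the trailing definitions moved before use and the vmapped function actually called (array values are the illustrative ones from the sketch):

```python
import jax
import jax.numpy as jnp

# Define the array before the function that closes over it.
my_array = jnp.asarray([3.2, 4.3, 5.8])

def simple_test(index):
    # Set one entry to 3.0, then reduce; vmap batches over `index`.
    new_array = my_array.at[index].set(3.0)
    return jnp.sum(new_array ** 2)

vmapped_test = jax.vmap(simple_test)
indices = jnp.asarray([0, 1, 2])
results = vmapped_test(indices)  # -> approx. [61.13, 52.88, 37.73]
```

So vmapping an indexed `.at[...].set(...)` update over a batch of indices does work for plain JAX arrays.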
|
I think the hard problem is solving the following: Suppose you insert a channel into only a few compartments. Then there are essentially two variants.

Version 1, where we pad with NaNs:

nodes = {"idx": jnp.array([0,1,2,3, ...])}
# insert channel into comp 0
nodes = {"idx": jnp.array([0,1,2,3, ...]),
"gNa": jnp.array([0.1, NaN, NaN, NaN, ...])
}
# insert channel into comp 2
nodes = {"idx": jnp.array([0,1,2,3, ...]),
"gNa": jnp.array([0.1,NaN, 0.1, NaN, ...])
}

Version 2, where we make it sparse, but instead need to keep track of what was inserted where etc. (which is tricky with shared params, i.e. what if two different channels both share calcium_conc as a state, but one is inserted into [0,1] and the other into [1,2]?):

nodes = {"idx": jnp.array([0,1,2,3, ...])}
# insert channel into comp 0
nodes = {"idx": jnp.array([0,1,2,3, ...]),
"gNa": jnp.array([0.1, ...]),
"Na": jnp.array([True, False, False, False]),
}
# insert channel into comp 2
nodes = {"idx": jnp.array([0,1,2,3, ...]),
"gNa": jnp.array([0.1,0.1, ...]),
"Na": jnp.array([True, False, True, False]),
}

I have so far not been able to come up with a good way to do Version 2. Hence, despite the potential memory inefficiency, I prefer Version 1 since it is easier to work with and understand. jax.experimental.sparse is also something we could think about. Assuming we are not going to find a good way to do Version 2 and have to pad using NaNs, I see two ways to do this:
nodes = {"idx": jnp.array([0,1,2,3, ...]),
"gNa": jnp.array([0.1,NaN, 0.1, NaN, ...])
}

and carry around all NaNs.

I slightly prefer 1., since we need the NaNs anyway. Inserting, for example, via ... I hope the above makes some sense. Lemme know your thoughts. |
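For reference, working with the NaN-padded arrays from Version 1 is straightforward; `insert_into` below is a hypothetical helper (not jaxley API), and all values are illustrative:

```python
import jax.numpy as jnp

# NaN marks compartments that do not carry the channel.
nodes = {
    "idx": jnp.array([0, 1, 2, 3]),
    "gNa": jnp.array([0.1, jnp.nan, jnp.nan, jnp.nan]),
}

def insert_into(nodes, key, comp_inds, value):
    """Overwrite the NaN padding at `comp_inds` with `value` (hypothetical helper)."""
    updated = dict(nodes)
    updated[key] = updated[key].at[jnp.asarray(comp_inds)].set(value)
    return updated

# Insert the channel into compartment 2 as well.
nodes = insert_into(nodes, "gNa", [2], 0.1)

# The compartments carrying the channel are exactly the non-NaN entries.
has_na = ~jnp.isnan(nodes["gNa"])
```

The mask recovers the "what was inserted where" bookkeeping from Version 2 on demand, without storing it separately.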
|
This all makes sense, and I agree with your judgement on what the best solution is:

nodes = {"idx": jnp.array([0,1,2,3, ...]),
"gNa": jnp.array([0.1,NaN, 0.1, NaN, ...])
} |
|
I spent some more time thinking about this, and something like the following should also work, since we still have to loop over channels / synapses anyway:

morph = {
"nodes": {
"idx": jnp.array([0, 1, 2, 3]),
"params": {
"radius": jnp.array([1.0, 1.0, 1.0, 1.0]),
"capacitance": jnp.array([1.0, 1.0, 1.0, 1.0]),
},
"states": {
"v": jnp.array([0.0, 0.0, 0.0, 0.0]),
},
#...
},
"edges": {
"idx": jnp.array([[0, 0], [1, 2]]),
#...
},
}
mechanisms = {
"channels": {
"Na1": {
"Na1_idx": jnp.array([0, 2]),
"params": {
"Na1_g": jnp.array([0.1, 0.1]),
"shared_param": jnp.array([0.1, 0.1]),
},
"states": {
"Na1_m": jnp.array([0.1, 0.1]),
"Na1_h": jnp.array([0.1, 0.1]),
},
},
"Na2": {
"Na2_idx": jnp.array([0, 1]),
"params": {
"Na2_g": jnp.array([0.1, 0.1]),
"shared_param": jnp.array([0.1, 0.1]),
},
"states": {
"Na2_m": jnp.array([0.1, 0.1]),
"Na2_h": jnp.array([0.1, 0.1]),
},
},
},
"synapses": {
"Ionotropic": {
"Ionotropic_idx": jnp.array([[0, 0]]),
"params": {
"Ionotropic_g": jnp.array([0.1]),
},
"states": {
"Ionotropic_s": jnp.array([0.1]),
},
},
},
}

Lines 2448 to 2474 in c6d11cd could be modified to something like:

for name, channel_data in self.pytree["channels"].items():
    channel_voltages = voltages[channel_data[f"{name}_idx"]]
    states_updated = self.channels[name].update_states(
        channel_data["states"], delta_t, channel_voltages, channel_data["params"]
    )

I see two problems with this approach, though.
mechanisms = {
"channels": {
"Na1": {
"Na1_idx": jnp.array([0, 2]),
"params": {
"Na1_g": jnp.array([0.1, 0.1]),
},
"states": {
"Na1_m": jnp.array([0.1, 0.1]),
"Na1_h": jnp.array([0.1, 0.1]),
},
},
"Na2": {
"Na2_idx": jnp.array([0, 1]),
"params": {
"Na2_g": jnp.array([0.1, 0.1]),
},
"states": {
"Na2_m": jnp.array([0.1, 0.1]),
"Na2_h": jnp.array([0.1, 0.1]),
},
},
"__shared__": {
"params": {
"shared_param": jnp.array([0.1, 0.1, 0.1]),
},
"states": {
"shared_state": jnp.array([0.1, 0.1, 0.1]),
},
},
},
#...
}

This somewhat mimics the current:

cell.comp([0, 1]).insert(Na())
cell.comp([0]).set("Na_g", 0.2)
cell.channels["Na"].params["Na_g"]  # -> jnp.array([0.1, 0.2])

(assuming cell.channels is a dict and not a list, as we have it currently). Again, the only thing I am unsure about is param / state sharing. Looking forward to hearing your thoughts on this. EDIT: we can also remove the "states" and "params" level, but then the looping would not be as nice, since ... |
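The proposed per-channel loop can be exercised on a toy pytree. `update_states` below is only a stand-in for a real gating update, and all names and values are illustrative:

```python
import jax.numpy as jnp

voltages = jnp.array([-70.0, -65.0, -60.0, -55.0])  # one entry per compartment

# Dense per-channel arrays: each channel stores only its own compartments.
mechanisms = {
    "channels": {
        "Na1": {
            "Na1_idx": jnp.array([0, 2]),
            "params": {"Na1_g": jnp.array([0.1, 0.1])},
            "states": {"Na1_m": jnp.array([0.1, 0.1])},
        },
        "Na2": {
            "Na2_idx": jnp.array([0, 1]),
            "params": {"Na2_g": jnp.array([0.1, 0.1])},
            "states": {"Na2_m": jnp.array([0.1, 0.1])},
        },
    },
}

def update_states(states, delta_t, v, params):
    # Placeholder dynamics: relax each state toward 1.0.
    return {k: s + delta_t * (1.0 - s) for k, s in states.items()}

delta_t = 0.025
for name, channel_data in mechanisms["channels"].items():
    # Gather only the voltages of the compartments this channel sits in.
    channel_voltages = voltages[channel_data[f"{name}_idx"]]
    channel_data["states"] = update_states(
        channel_data["states"], delta_t, channel_voltages, channel_data["params"]
    )
```

Note that the state arrays stay dense (no NaN padding) because each channel carries its own index array; the gather happens only on `voltages`.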
#########################################
#########################################
#########################################
#########################################
mydict = {
0: 3 * jnp.ones((5,)),
1: 4 * jnp.ones((5,)),
2: 5 * jnp.ones((5,)),
}
def stimulate(index):
    return mydict[index].at[2].add(2.0)
vmapped_stim = jax.vmap(stimulate)
# Test
stimulate(jnp.asarray(2))
arr = jnp.asarray([0, 1, 2])
vmapped_stim(arr) # Fails
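The failure is expected: a traced `index` cannot be used as a Python dict key. Besides the flat layout shown in the next block, one workaround is to stack the dict values into a single array before tracing (a sketch, not jaxley API):

```python
import jax
import jax.numpy as jnp

mydict = {
    0: 3 * jnp.ones((5,)),
    1: 4 * jnp.ones((5,)),
    2: 5 * jnp.ones((5,)),
}

# Stack the per-key arrays into one (3, 5) array outside the traced
# function, so indexing happens on a JAX array, not a Python dict.
stacked = jnp.stack([mydict[k] for k in sorted(mydict)])

def stimulate(index):
    return stacked[index].at[2].add(2.0)

out = jax.vmap(stimulate)(jnp.asarray([0, 1, 2]))  # shape (3, 5), no failure
```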
#########################################
#########################################
#########################################
#########################################
mydict = {
"inds": jnp.asarray([0, 1, 2]),
"vals": jnp.asarray([3.0, 4.0, 5.0]),
}
def stimulate(index):
    my_ind = mydict["inds"][index]
    return mydict["vals"][my_ind] + 2.0
vmapped_stim = jax.vmap(stimulate)
# Test
stimulate(jnp.asarray(2))
arr = jnp.asarray([0, 1, 2])
vmapped_stim(arr) |
df = net.edges  # this is a `property` that automatically builds the df.
df = df.query("pre_global_cell_index in [3, 4]")
net.select(edges=df.index).set("Ionotropic_gS", 0.2)

This is a bit annoying because it now requires the user to understand that there is a difference between

filtered_node_inds = [net.cell(len(synapse_locations)).branch(b).loc(l).jax_nodes.index for b, l in zip(branches, locs)]
post = net.select(nodes=filtered_node_inds)

def simulate(vals):
    return cell.select(nodes=filtered_node_inds).set("length", vals) |
|
Meeting summary: We decided on the following structure:

nodes = {"idx": jnp.array([0,1,2,3, ...])}
# insert channel into comp 0
nodes = {"idx": jnp.array([0,1,2,3, ...]),
"gNa": jnp.array([0.1, NaN, NaN, NaN, ...])
}
# insert channel into comp 2
nodes = {"idx": jnp.array([0,1,2,3, ...]),
"gNa": jnp.array([0.1,NaN, 0.1, NaN, ...])
}

We keep ... Viewing will work on the backend; pandas will only be used to render the pytree. |
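A minimal sketch of that decision, assuming a flat pytree of NaN-padded arrays with pandas used purely for rendering (column names and values are illustrative):

```python
import numpy as np
import pandas as pd
import jax.numpy as jnp

# The decided structure: flat dict of arrays, NaN-padded per channel.
nodes = {
    "idx": jnp.array([0, 1, 2, 3]),
    "gNa": jnp.array([0.1, jnp.nan, 0.1, jnp.nan]),
}

# Viewing: build a DataFrame on the fly just for display / querying;
# the jnp arrays in `nodes` remain the single source of truth.
df = pd.DataFrame({k: np.asarray(v) for k, v in nodes.items()})
```

Because the DataFrame is rebuilt from the pytree on demand, no state is duplicated between the pandas view and the backend.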
This PR attempts to move jaxley to a more flexible backend.
Goals:
- `to_jax` before simulation

My first idea was to create a custom pytree that holds ... from which one can easily go back and forth between pandas and networkx if desired, and which is also broadly in line with the current pandas way of doing things. See example below:
Pros:
Things that are not yet solved in the above example:
While `jax.tree_util.tree_map(lambda *args: jnp.array(args), *node_attrs.values())` would be efficient and fast, it only works if the structure of all the inner trees (i.e. each comp) is the same. Inserting a channel only in one compartment would mean either inserting NaNs everywhere else or doing the flattening differently / manually (which works, but would mean having a `to_jax` equivalent). One could think about using https://docs.jax.dev/en/latest/jax.experimental.sparse.html potentially.

Thoughts I also had:
[{"Synapse1": True, "Synapse2": True, "Synapse1_gS": 1e-3, "Synapse2_gS": 1e-3}]

vs.

[{"Synapse": Synapse1, "Synapse1_gS": 1e-3, "Synapse2_gS": NaN}, {"Synapse": Synapse2, "Synapse1_gS": NaN, "Synapse2_gS": 1e-3}]

Addresses #557, #555
Other related issues: #644, #632
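For reference, the tree_map stacking constraint mentioned above (identical inner structure per compartment) can be demonstrated directly; values are illustrative:

```python
import jax
import jax.numpy as jnp

# Same keys in every compartment: stacking via tree_map works.
node_attrs = {
    0: {"radius": jnp.array(1.0), "v": jnp.array(-70.0)},
    1: {"radius": jnp.array(2.0), "v": jnp.array(-65.0)},
}
stacked = jax.tree_util.tree_map(
    lambda *args: jnp.asarray(args), *node_attrs.values()
)
# stacked["radius"] is jnp.array([1.0, 2.0])

# A channel inserted into only one compartment changes that comp's
# inner structure, and tree_map raises instead of stacking.
node_attrs[1]["gNa"] = jnp.array(0.1)
try:
    jax.tree_util.tree_map(lambda *args: jnp.asarray(args), *node_attrs.values())
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```

This is exactly why the per-compartment layout forces either NaN padding or a manual `to_jax`-style flattening step.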