
New Pytree Backend #646

Open

jnsbck wants to merge 24 commits into main from graph_backend

New Pytree Backend#646
jnsbck wants to merge 24 commits into
mainfrom
graph_backend

Conversation

@jnsbck
Contributor

@jnsbck jnsbck commented Jun 4, 2025

This PR attempts to move jaxley to a more flexible backend.

Goals:

  • allow parameters and states to be arrays
  • allow vmap over stimuli
  • ideally avoid calling to_jax before simulation
  • easy conversion / interoperability with networkx and pandas

My first idea was to create a custom pytree that holds

node_attrs = {0: {"l":0.1, "x": 0.0, "y": 0.0, "z": 0.0, "r": 0.0}, ...}
edge_attrs = {(0,1): {"IonotropicSynapse_gS": 0.001}, ...}
global_attrs = {"xyzr": [....], ...}

from which one can easily go back and forth between pandas and networkx if desired and which is also broadly in line with the current pandas way of doing things. See example below:

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Tuple

import jax
import jax.numpy as jnp
import networkx as nx
import numpy as np
import pandas as pd


@jax.tree_util.register_dataclass
@dataclass
class MorphTree:
    node_attrs: Dict[int, Dict[str, Any]]
    edge_attrs: Dict[Tuple[int, int], Dict[str, Any]]
    global_attrs: Dict[str, Any] = field(default_factory=dict)

    @property
    def nodes(self) -> jnp.ndarray:
        """Returns the node indices as a jax array."""
        return jnp.array(list(self.node_attrs.keys())).astype(int)

    @property
    def edges(self) -> jnp.ndarray:
        """Returns the edge indices as a jax array."""
        return jnp.array(list(self.edge_attrs.keys())).astype(int)

    def __repr__(self) -> str:
        n_nodes = len(self.node_attrs)
        n_edges = len(self.edge_attrs)

        node_keys = list(next(iter(self.node_attrs.values())).keys())
        if len(self.edge_attrs) > 0:
            edge_keys = list(next(iter(self.edge_attrs.values())).keys())
        else:
            edge_keys = []

        node_attrs = node_keys if len(self.node_attrs) > 0 else []
        edge_attrs = edge_keys if len(self.edge_attrs) > 0 else []
        return f"MorphTree(nodes={n_nodes}*{node_attrs}, edges={n_edges}*{edge_attrs}, global={list(self.global_attrs.keys())})"
    
    def node(self, i: int) -> Dict[str, Any]:
        """Returns the node attributes for the node with index i."""
        return self.node_attrs[i]
    
    def edge(self, i: int, j: int) -> Dict[str, Any]:
        """Returns the edge attributes for the edge between nodes i and j."""
        return self.edge_attrs[i, j]
    
    def to_nx(self) -> nx.DiGraph:
        """Returns the MorphTree as a networkx DiGraph."""
        G = nx.DiGraph()
        G.add_nodes_from(self.node_attrs.items())
        G.add_edges_from((i, j, d) for (i, j), d in self.edge_attrs.items())
        G.graph.update(self.global_attrs)
        return G
    
    @staticmethod
    def from_nx(G: nx.DiGraph) -> MorphTree:
        """Returns a MorphTree from a networkx DiGraph."""
        node_attrs = {n: G.nodes[n] for n in G.nodes}
        edge_attrs = {(i, j): G.edges[i, j] for i, j in G.edges}
        return MorphTree(node_attrs, edge_attrs, G.graph)
    
    def to_pandas(self, return_global_attrs: bool = True):
        """Returns the MorphTree as pandas DataFrames: (node_df, edge_df[, global_series])."""
        node_df = pd.DataFrame(self.node_attrs.values(), index=self.node_attrs.keys())
        edge_df = pd.DataFrame(self.edge_attrs.values(), index=self.edge_attrs.keys())
        edge_index = pd.MultiIndex.from_arrays(np.array(self.edges).T)
        edge_df = edge_df.set_index(edge_index)

        if return_global_attrs:
            return node_df, edge_df, pd.Series(self.global_attrs)
        return node_df, edge_df
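For illustration, the nodes-to-pandas round trip underlying to_pandas can be sketched without the class (a minimal example; the attribute names are made up):

```python
import pandas as pd

node_attrs = {0: {"l": 0.1, "r": 1.0}, 1: {"l": 0.2, "r": 2.0}}

# Dict-of-dicts -> DataFrame (rows indexed by node id) and back.
node_df = pd.DataFrame(node_attrs.values(), index=node_attrs.keys())
roundtrip = node_df.to_dict(orient="index")
assert roundtrip == node_attrs
```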

Pros:

  • easy to manipulate
  • sparse / memory efficient
  • relatively simple

Things that are not yet solved in the above example:

  • this structure would need to be transposed in order to vmap over it. While jax.tree_util.tree_map(lambda *args: jnp.array(args), *node_attrs.values()) would be efficient and fast, it only works if all the inner trees (i.e. each comp) have the same structure. Inserting a channel into only one compartment would mean either inserting NaNs everywhere else or doing the flattening differently / manually (which works, but would mean having a to_jax equivalent). One could also think about using https://docs.jax.dev/en/latest/jax.experimental.sparse.html.
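The tree_map transpose mentioned above can be sketched as follows (a minimal, self-contained example with made-up attribute values; it breaks as soon as the inner dicts differ in structure):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-compartment attributes, all with the same inner structure.
node_attrs = {
    0: {"radius": 1.0, "length": 10.0},
    1: {"radius": 2.0, "length": 20.0},
}

# Transpose the pytree: one dict of stacked arrays instead of one dict per node.
stacked = jax.tree_util.tree_map(
    lambda *args: jnp.array(args), *node_attrs.values()
)
# stacked == {"radius": [1.0, 2.0], "length": [10.0, 20.0]}
```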

Thoughts I also had:

  • Move synapse_edges indexing to tuples (corresponding to nodes/comps that are connected by a synapse)
  • store multiple synapses per row in edges -> make it less sparse, i.e. [{"Synapse1": True, "Synapse2": True, "Synapse1_gS": 1e-3, "Synapse2_gS": 1e-3}] vs. [{"Synapse": Synapse1, "Synapse1_gS": 1e-3, "Synapse2_gS": NaN}, {"Synapse": Synapse2, "Synapse1_gS": NaN, "Synapse2_gS": 1e-3}]
  • global_attrs could contain global params and states (see "A global state for Jaxley modules", #476); we should make sure they can also be more than just floats

Addresses #557, #555
Other related issues: #644, #632

@michaeldeistler
Contributor

michaeldeistler commented Jun 4, 2025

Things we agreed upon:

  • indexing and traversal will be done via pandas. The pd.DataFrame will always be an attribute. It will not be built on the fly.

There are two options for the central backend. Option 1:

import jax.numpy as jnp

edges = {
    (0, 1): {"radius": jnp.asarray([1.0, 1.0])},
    (1, 2): {"radius": jnp.asarray([1.0, 1.0])},
    (2, 3): {},
}
nodes = {
    0: {"radius": jnp.asarray([0.0])},
    1: {"radius": jnp.asarray([0.0]), "length": jnp.asarray([1.0])},
    2: {"neural_net": ...},  # any pytree of arrays
    3: {"gNa": jnp.asarray([0, 1])},
}

# Arguments pro:
# - nodes[3]["radius"] = 1.0

# Arguments against making this the _central structure_:
# - we are not sure if vmap across stim-compartments works.
# - generally vmap is more difficult and could (maybe) be done via tree_utils.
# - we would have to build the structure below anyways.
# - this is a larger refactor.

Option 2:

nodes = {
    "index": jnp.zeros((4,)),
    "radius": jnp.zeros((4,)),
    "gNa": jnp.zeros((4, 2)),
    "neural_net": {},  # Dict[str, jnp.ndarray], one (4, 3) array per entry
}
indices = nodes_pd.query(...).index
nodes["radius"] = nodes["radius"].at[jnp.asarray(indices)].set(1.0)

edges = {
    "pre": jnp.asarray([0, 1, 2]),
    "post": jnp.asarray([0, 1, 2]),
    "radius": jnp.asarray([0.0, 1.0, jnp.nan]),
}

A nice and desirable feature would be to vmap across stimuli, e.g.:

# vmap across stimulus locations.
stim_comps = [0, 1, 2]

def run(stim_comp):
    cell.comp(stim_comp).stimulate(...)
    return jx.integrate(cell)

vmapped_run = vmap(run)
vmapped_run(stim_comps)

This is, in principle, supported for JAX arrays:

my_array = jnp.asarray([3.2, 4.3, 5.8])

def simple_test(index):
    new_array = my_array.at[index].set(3.0)
    return jnp.sum(new_array ** 2)

indices = jnp.asarray([0, 1, 2])
vmapped_test = vmap(simple_test)
vmapped_test(indices)  # one scatter-update per index, all under one vmap

@jnsbck jnsbck changed the title New Graph Backend and SWC to Jaxley Pipeline New Graph Backend Jun 10, 2025
@jnsbck jnsbck changed the title New Graph Backend New Pytree Backend Jun 10, 2025
@jnsbck
Contributor Author

jnsbck commented Jun 10, 2025

I think the hard problem is solving the following:

Suppose you insert a channel into only a few compartments. Then there are essentially two variants:

Version 1, where we pad with NaNs

nodes = {"idx": jnp.array([0,1,2,3, ...])}
# insert channel into comp 0
nodes = {"idx": jnp.array([0,1,2,3, ...]),
         "gNa": jnp.array([0.1, NaN, NaN, NaN, ...])
         }
# insert channel into comp 2
nodes = {"idx": jnp.array([0,1,2,3, ...]),
         "gNa": jnp.array([0.1,NaN, 0.1, NaN, ...])
         }

Version 2, where we make it sparse but instead need to keep track of what was inserted where, etc. (which is tricky with shared params, i.e. what if two different channels both share calcium_conc as a state, but one is inserted into [0,1] and the other into [1,2]).

nodes = {"idx": jnp.array([0,1,2,3, ...])}
# insert channel into comp 0
nodes = {"idx": jnp.array([0,1,2,3, ...]),
         "gNa": jnp.array([0.1, ...]),
         "Na": jnp.array([True, False, False, False]),
         }
# insert channel into comp 2
nodes = {"idx": jnp.array([0,1,2,3, ...]),
         "gNa": jnp.array([0.1,0.1, ...]),
         "Na": jnp.array([True, False, True, False]),
         }
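To illustrate the bookkeeping that version 2 would require for shared params / states (a hypothetical sketch; masks and names are made up): a state shared by two channels would have to live in the union of both insertion masks:

```python
import jax.numpy as jnp

num_comps = 4
# Channel A inserted into comps [0, 1], channel B into comps [1, 2].
mask_a = jnp.array([True, True, False, False])
mask_b = jnp.array([False, True, True, False])

# A shared state (e.g. calcium_conc) must exist wherever either channel is.
shared_mask = mask_a | mask_b
shared_idx = jnp.where(shared_mask)[0]  # -> [0, 1, 2]
```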

I have so far not been able to come up with a good way to implement version 2. Hence, despite the potential memory inefficiency, I prefer version 1, since it is easier to work with and understand. Also, jax.experimental.sparse is something we could think about.

Assuming we are not going to find a good way to do version 2 and have to pad using NaNs, I see 2 ways to do this:

  1. Work with {"idx": jnp.array([0,1,2,3,...]), "param": jnp.array([]),...} directly
nodes = {"idx": jnp.array([0,1,2,3, ...]),
         "gNa": jnp.array([0.1,NaN, 0.1, NaN, ...])
         }

and carry around all NaNs.

  2. Work with {0: {"gNa": 0.1, ...}, ...} and add a to_jax equivalent to transpose the pytree and pad with NaNs when needed, i.e. via padding + jax.tree_util.tree_map(lambda *args: jnp.array(args), *node_attrs.values())

I slightly prefer 1., since we need the NaNs anyways. Inserting, for example, via nodes["gNa"] = jnp.nan * jnp.ones(num_comps); nodes["gNa"] = nodes["gNa"].at[idxs].set(0.1) is not actually that bad, and we would not need the to_jax function. The only thing I do not like is that it is a bit harder for the user to parse / traverse.
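The insert-via-NaN pattern of option 1 is runnable as-is (a small sketch with made-up values):

```python
import jax.numpy as jnp

num_comps = 4
nodes = {"idx": jnp.arange(num_comps)}

# Insert a channel into comps 0 and 2: all other comps are padded with NaN.
idxs = jnp.array([0, 2])
nodes["gNa"] = jnp.nan * jnp.ones(num_comps)
nodes["gNa"] = nodes["gNa"].at[idxs].set(0.1)
# -> [0.1, nan, 0.1, nan]
```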

I hope the above makes some sense. Lemme know your thoughts.

@michaeldeistler
Contributor

michaeldeistler commented Jun 10, 2025

This all makes sense, and I agree with your judgement on what the best solution is:

nodes = {"idx": jnp.array([0,1,2,3, ...]),
         "gNa": jnp.array([0.1,NaN, 0.1, NaN, ...])
         }

@jnsbck
Contributor Author

jnsbck commented Jun 11, 2025

I spent some more time thinking about this, and something like the following should also work, since we still have to loop over channels / synapses anyways.

morph = {
    "nodes": {
        "idx": jnp.array([0, 1, 2, 3]),
        "params": {
            "radius": jnp.array([1.0, 1.0, 1.0, 1.0]),
            "capacitance": jnp.array([1.0, 1.0, 1.0, 1.0]),
        },
        "states": {
            "v": jnp.array([0.0, 0.0, 0.0, 0.0]),
        },
        #...
    },
    "edges": {
        "idx": jnp.array([[0, 0], [1, 2]]),
        #...
    },
}

mechanisms = {
    "channels": {
        "Na1": {
            "Na1_idx": jnp.array([0, 2]),
            "params": {
                "Na1_g": jnp.array([0.1, 0.1]),
                "shared_param": jnp.array([0.1, 0.1]),
            },
            "states": {
                "Na1_m": jnp.array([0.1, 0.1]),
                "Na1_h": jnp.array([0.1, 0.1]),
            },
        },
        "Na2": {
            "Na2_idx": jnp.array([0, 1]),
            "params": {
                "Na2_g": jnp.array([0.1, 0.1]),
                "shared_param": jnp.array([0.1, 0.1]),
            },
            "states": {
                "Na2_m": jnp.array([0.1, 0.1]),
                "Na2_h": jnp.array([0.1, 0.1]),
            },
        },
    },
    "synapses": {
        "Ionotropic": {
            "Ionotropic_idx": jnp.array([[0, 0]]),
            "params": {
                "Ionotropic_g": jnp.array([0.1]),
            },
            "states": {
                "Ionotropic_s": jnp.array([0.1]),
            },
        },
    },
}

jaxley/jaxley/modules/base.py, lines 2448 to 2474 in c6d11cd:

indices = channel_nodes.index.to_numpy()
for channel in channels:
    channel_param_names = list(channel.channel_params)
    channel_param_names += [
        "radius",
        "length",
        "axial_resistivity",
        "capacitance",
    ]
    channel_state_names = list(channel.channel_states)
    channel_state_names += self.membrane_current_names
    channel_indices = indices[channel_nodes[channel._name].astype(bool)]
    channel_params = query_channel_states_and_params(
        params, channel_param_names, channel_indices
    )
    channel_states = query_channel_states_and_params(
        states, channel_state_names, channel_indices
    )
    states_updated = channel.update_states(
        channel_states, delta_t, voltages[channel_indices], channel_params
    )
    # Rebuild state. This has to be done within the loop over channels to allow
    # multiple channels which modify the same state.
    for key, val in states_updated.items():
        states[key] = states[key].at[channel_indices].set(val)

Could be modified to sth. like:

for name, channel_data in self.pytree["channels"].items():
    channel_voltages = voltages[channel_data[f"{name}_idx"]]
    states_updated = self.channels[name].update_states(
        channel_data["states"], delta_t, channel_voltages, channel_data["params"]
    )
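A self-contained sketch of that loop with a dummy update rule (the channel name, update rule, and values are all made up for illustration; the real channel.update_states would be used instead):

```python
import jax.numpy as jnp

# Hypothetical per-channel pytree, mirroring the structure above.
channels = {
    "Na1": {
        "Na1_idx": jnp.array([0, 2]),
        "params": {"Na1_g": jnp.array([0.1, 0.1])},
        "states": {"Na1_m": jnp.array([0.5, 0.5])},
    },
}
voltages = jnp.array([-70.0, -65.0, -60.0, -55.0])
delta_t = 0.025

def dummy_update_states(states, delta_t, v, params):
    # Placeholder for channel.update_states; not a real gating equation.
    return {"Na1_m": states["Na1_m"] + delta_t * params["Na1_g"] * v}

for name, data in channels.items():
    channel_voltages = voltages[data[f"{name}_idx"]]
    data["states"] = dummy_update_states(
        data["states"], delta_t, channel_voltages, data["params"]
    )
```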

I see 2 problems with this approach though.

  1. It will either be messy to manipulate or have to be constructed from another representation, i.e. {0: {"gNa": 0.1, ...}, ...}.
  2. I am not exactly sure how to solve the problem of shared params, i.e. how to deal with shared_param above. The only way I can think of is to do something like the below, which is something I have also attempted to implement in Simpler edge indexing (#487):
mechanisms = {
    "channels": {
        "Na1": {
            "Na1_idx": jnp.array([0, 2]),
            "params": {
                "Na1_g": jnp.array([0.1, 0.1]),
            },
            "states": {
                "Na1_m": jnp.array([0.1, 0.1]),
                "Na1_h": jnp.array([0.1, 0.1]),
            },
        },
        "Na2": {
            "Na2_idx": jnp.array([0, 1]),
            "params": {
                "Na2_g": jnp.array([0.1, 0.1]),
            },
            "states": {
                "Na2_m": jnp.array([0.1, 0.1]),
                "Na2_h": jnp.array([0.1, 0.1]),
            },
        },
        "__shared__": {
            "params": {
                "shared_param": jnp.array([0.1, 0.1, 0.1]),
            },
            "states": {
                "shared_state": jnp.array([0.1, 0.1, 0.1]),
            },
        },
    },
#...
}

This somewhat mimics the current channels / synapses structure, and we could even think about using the channel / synapse objects directly instead of dictionaries, which I actually like. Inserting a channel into 2 compartments would then do the following:

cell.comp([0,1]).insert(Na())
cell.comp([0]).set("Na_g", 0.2)
cell.channels["Na"].params["Na_g"] # -> jnp.array([0.1, 0.2])

(assuming cell.channels is a dict and not a list, as we have it currently). Again, the only thing I am unsure about is param / state sharing.

Looking forward to hearing your thoughts on this.

EDIT:

we can also remove the "states" and "params" level, but then the looping would not be as nice, since channel.update_states expects separate inputs for states and params, and we would have to filter based on channel.params.keys()

@michaeldeistler
Contributor

#########################################
#########################################
#########################################
#########################################
mydict = {
    0: 3 * jnp.ones((5,)),
    1: 4 * jnp.ones((5,)),
    2: 5 * jnp.ones((5,)),
}

def stimulate(index):
    return mydict[index].at[2].add(2.0)

vmapped_stim = jax.vmap(stimulate)

# Test
stimulate(2)  # works with a plain int key (jax arrays are unhashable)
arr = jnp.asarray([0, 1, 2])
vmapped_stim(arr)  # Fails


#########################################
#########################################
#########################################
#########################################
mydict = {
    "inds": jnp.asarray([0, 1, 2]),
    "vals": jnp.asarray([3.0, 4.0, 5.0]),
}

def stimulate(index):
    my_ind = mydict["inds"][index]
    return mydict["vals"][my_ind] + 2.0

vmapped_stim = jax.vmap(stimulate)

# Test
stimulate(jnp.asarray(2))
arr = jnp.asarray([0, 1, 2])
vmapped_stim(arr)  # works: integer indexing into arrays is vmappable

@michaeldeistler
Contributor

df = net.edges  # this is a `property` that automatically builds the df.
df = df.query("pre_global_cell_index in [3, 4]")
net.select(edges=df.index).set("Ionotropic_gS", 0.2)

This is a bit annoying because it now requires the user to understand that there is a difference between pd_nodes and nodes:

filtered_node_inds = [net.cell(len(synapse_locations)).branch(b).loc(l).jax_nodes.index for b, l in zip(branches, locs)]
post = net.select(nodes=filtered_node_inds)


def simulate(vals):
    return cell.select(nodes=filtered_node_inds).set("length", vals)

@jnsbck
Contributor Author

jnsbck commented Jun 12, 2025

meeting summary:

We decided on the following structure:

nodes = {"idx": jnp.array([0,1,2,3, ...])}
# insert channel into comp 0
nodes = {"idx": jnp.array([0,1,2,3, ...]),
         "gNa": jnp.array([0.1, NaN, NaN, NaN, ...])
         }
# insert channel into comp 2
nodes = {"idx": jnp.array([0,1,2,3, ...]),
         "gNa": jnp.array([0.1,NaN, 0.1, NaN, ...])
         }

We keep NaNs. If we become memory bound, we can still give jax.experimental.sparse a shot. (TODO for myself to try this out on a small scale).
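A small-scale sketch of what jax.experimental.sparse could look like here (note that BCOO uses 0 as the fill value, not NaN, so the NaN convention would need rethinking; values are made up):

```python
import jax.numpy as jnp
from jax.experimental import sparse

# gNa inserted only into comps 0 and 2; zeros elsewhere.
dense = jnp.array([0.1, 0.0, 0.1, 0.0])
g_na = sparse.BCOO.fromdense(dense)

assert g_na.nse == 2  # only the two inserted entries are stored
assert jnp.allclose(g_na.todense(), dense)
```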

Viewing will operate on the pytree backend; pandas will only be used to render the pytree.
