This repository implements inference and parameter learning for a Tree Autoregressive Hidden Markov Model (Tree AR-HMM) using Dynamax (JAX). It was inspired by the problem of modeling cell lineages: we would like to learn the latent states of cells, but naively applying an AR-HMM treats daughter cells as independent of their parents. Division events create a branching tree structure, and we assume that the latent state of a daughter cell depends on its parent's latent state at division. The implementation also accommodates spontaneous cell birth (e.g. a cell coming into frame at a later time).
In summary:
- Each cell (or agent, more generally) has a latent discrete state that evolves over time.
- Emissions follow AR(1) dynamics that are shared across cells and conditioned on the latent state.
- When a cell divides, its daughters’ initial latent states are drawn from a division-specific transition kernel that depends on the parent’s latent state at division.
We index cells by
Each cell has a lifetime interval over which it exists and can persist or divide.
For each root cell $r$:
- Initial state
  $$z_{r,1} \sim \mathrm{Cat}(\pi_0)$$
- Within-cell transitions
  $$z_{r,t+1} \mid z_{r,t} = k \sim \mathrm{Cat}(\pi_k), \quad t = 1,\dots,\tilde t - 1$$
Here,
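The root-cell dynamics above can be sketched in JAX as follows. This is a minimal illustration, not the repository's API: the function name `sample_root_states` is hypothetical, and it assumes that $\pi_k$ is stored as row `k` of a transition matrix `P`.

```python
import jax
import jax.numpy as jnp


def sample_root_states(key, pi0, P, num_steps):
    """Sample z_1..z_T for one root cell (hypothetical helper).

    z_1 ~ Cat(pi0) and z_{t+1} | z_t = k ~ Cat(P[k]).
    Assumption: row k of P plays the role of pi_k in the text.
    """
    key_init, key_rest = jax.random.split(key)
    z1 = jax.random.categorical(key_init, jnp.log(pi0))

    def step(z_prev, step_key):
        # Draw the next state from the row selected by the previous state.
        z_next = jax.random.categorical(step_key, jnp.log(P[z_prev]))
        return z_next, z_next

    step_keys = jax.random.split(key_rest, num_steps - 1)
    _, z_rest = jax.lax.scan(step, z1, step_keys)
    return jnp.concatenate([z1[None], z_rest])


# Illustrative two-state parameters (made up for this example).
pi0 = jnp.array([0.5, 0.5])
P = jnp.array([[0.9, 0.1],
               [0.2, 0.8]])
z = sample_root_states(jax.random.PRNGKey(0), pi0, P, num_steps=10)
```

Using `jax.lax.scan` keeps the sampler compatible with `jax.jit`; a plain Python loop would work equally well for short sequences.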
For any cell
Conditional on the latent state, we use a shared AR(1) emission model:
For example, a Gaussian AR(1):
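As a hedged illustration, the sketch below evaluates a Gaussian AR(1) emission log-density for every latent state. It assumes scalar observations and the common parameterization $x_t \mid x_{t-1}, z_t = k \sim \mathcal{N}(a_k x_{t-1} + b_k, \sigma_k^2)$; the names `a`, `b`, `sigma`, and `ar1_log_likelihoods` are illustrative and may not match the repository's parameterization.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def ar1_log_likelihoods(x, a, b, sigma):
    """Log p(x_t | x_{t-1}, z_t = k) for t = 2..T and each state k.

    x: (T,) scalar observations (assumption: one-dimensional emissions).
    a, b, sigma: (K,) per-state AR coefficient, bias, and noise scale.
    Returns an array of shape (T-1, K).
    """
    x_prev = x[:-1, None]                      # (T-1, 1)
    x_next = x[1:, None]                       # (T-1, 1)
    mean = a[None, :] * x_prev + b[None, :]    # (T-1, K)
    return norm.logpdf(x_next, loc=mean, scale=sigma[None, :])


# Illustrative values (made up for this example).
x = jnp.array([0.0, 1.0, 2.0])
a = jnp.array([1.0, 0.5])
b = jnp.array([1.0, 0.0])
sigma = jnp.array([1.0, 2.0])
ll = ar1_log_likelihoods(x, a, b, sigma)
```

The resulting `(T-1, K)` array is the kind of per-step, per-state log-likelihood table that a forward–backward pass consumes.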
All emission parameters
Suppose a parent cell
Each daughter’s initial latent state at time
The collection
After birth, each daughter follows the same within-cell transition dynamics:
again with AR(1) emissions as above.
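The division step can be sketched as below. This is a minimal example under stated assumptions: the division kernel is stored as a matrix `P_div` whose row `k` gives the distribution over a daughter's initial state when the parent is in state `k`, and daughters are drawn independently given the parent's state. The function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_daughter_initial_states(key, parent_state, P_div, num_daughters=2):
    """Draw each daughter's initial latent state given the parent's state.

    Assumptions: P_div[k] is the division-specific distribution for a parent
    in state k, and daughters are conditionally independent given that state.
    """
    logits = jnp.log(P_div[parent_state])
    return jax.random.categorical(key, logits, shape=(num_daughters,))


# Illustrative division kernel (made up): daughters always switch state.
P_div = jnp.array([[0.0, 1.0],
                   [1.0, 0.0]])
daughters = sample_daughter_initial_states(jax.random.PRNGKey(0), 0, P_div)
```

After this draw, each daughter's subsequent states would be sampled with the ordinary within-cell transitions.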
Let
where the transition term
- the standard transition matrix $P$ for non-division (self) transitions;
- the division transition matrix $\tilde P$ for division events.
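A minimal sketch of how the two matrices might enter the latent-state log-probability: each edge of the lineage tree (a previous state, a next state, and a flag for whether the edge crosses a division) is scored with either $P$ or $\tilde P$. The edge-list representation and the function name are assumptions made for illustration, not the repository's data layout.

```python
import jax.numpy as jnp


def transition_log_prob(z_prev, z_next, is_division, P, P_div):
    """Sum of log transition probabilities over tree edges.

    z_prev, z_next: (E,) integer states at the two ends of each edge.
    is_division: (E,) booleans, True where the edge goes parent -> daughter.
    Assumption: P_div stores the division matrix written as P-tilde in the text.
    """
    log_within = jnp.log(P)[z_prev, z_next]
    log_division = jnp.log(P_div)[z_prev, z_next]
    # Pick the division matrix on division edges, the standard one otherwise.
    return jnp.sum(jnp.where(is_division, log_division, log_within))


# Illustrative parameters and a tiny tree: one within-cell step, then a
# division into two daughters (values made up for this example).
P = jnp.array([[0.9, 0.1],
               [0.2, 0.8]])
P_div = jnp.array([[0.7, 0.3],
                   [0.4, 0.6]])
z_prev = jnp.array([0, 0, 0])
z_next = jnp.array([0, 1, 0])
is_division = jnp.array([False, True, True])
total = transition_log_prob(z_prev, z_next, is_division, P, P_div)
```

Here `total` equals $\log 0.9 + \log 0.3 + \log 0.7$: one standard transition followed by two division transitions from the same parent state.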
The model is fit using EM; see the Derivation folder for forward–backward details.
- Optionally, division events themselves can be modeled by appending a division indicator to the observations.
- Tests
- Improve sampling and demonstration notebook
See the notebook for usage.

