Neural ODE-based training pipeline for modeling cellular dynamics in spatial-temporal single-cell data.
This repository contains training code for the DynVelocity model v15, which uses neural ordinary differential equations (ODEs) to model continuous cell state transitions over time. The model combines gene expression dynamics with spatial movement, making it particularly suitable for spatial transcriptomics data like MERFISH.
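As a rough sketch of the core idea (not the repository's actual model code), the joint (expression, position) state of each cell can be evolved in continuous time with `torchdiffeq`; all class and dimension names below are illustrative:

```python
# Minimal neural-ODE sketch: 30D expression embedding + 3D coordinates
# concatenated into one state vector. Names and sizes are illustrative.
import torch
import torch.nn as nn
from torchdiffeq import odeint


class CellDynamics(nn.Module):
    """dz/dt = f(z, t) over the joint (expression, position) state."""

    def __init__(self, expr_dim: int = 30, pos_dim: int = 3, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(expr_dim + pos_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, expr_dim + pos_dim),
        )

    def forward(self, t, state):
        return self.net(state)


func = CellDynamics()
z0 = torch.randn(256, 33)            # 256 cells: 30D expression + 3D position
t = torch.tensor([0.0, 1.0])         # integrate from t=0 to t=1
z1 = odeint(func, z0, t)[-1]         # predicted cell states at the target timepoint
```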
Key features:

- Neural ODE Framework: Uses `torchdiffeq` for continuous-time modeling
- Multi-modal Input:
  - Gene expression features (30D latent embeddings)
  - 3D spatial coordinates
- Spatial Attention: Message passing between neighboring cells
- Dual Velocity Prediction:
  - Expression velocity head (spectral-normalized MLPs)
  - Position velocity head (E3NN equivariant network)
- Energy Regularization: Optional kinetic energy term for smooth trajectories (sketched after this list)
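A minimal sketch of the kinetic-energy idea, assuming velocities come from an ODE function like the one above; this illustrates the regularizer, not the repository's implementation:

```python
import torch

def kinetic_energy(velocities: torch.Tensor) -> torch.Tensor:
    """Mean squared velocity norm; penalizing it encourages smooth,
    low-curvature trajectories when added to the training loss."""
    return velocities.pow(2).sum(dim=-1).mean()

# Hypothetical usage, with `func` an ODE function as sketched above:
# v = func(t, state)
# loss = data_loss + lambda_energy * kinetic_energy(v)
```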
The training pipeline proceeds as follows:

- Data Loading: Load AnnData objects with expression, spatial, and temporal information
- Mini-batch Sampling: Random sampling from each timepoint for scalability
- ODE Integration: Forward simulation from initial to target timepoints
- Loss Computation: Optimal transport distance between predicted and true distributions (a toy loop combining these steps is sketched after this list)
- Evaluation: Coverage metrics and spatial structure preservation assessment
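To make these steps concrete, here is a toy training iteration under assumed names (`func` and `odeint` from the sketch above, uniform marginals, squared-Euclidean cost via POT); the actual pipeline's batching, loss, and solver settings may differ:

```python
import torch
import ot  # POT
from torchdiffeq import odeint

def training_step(func, optimizer, x_t0, x_t1, mini_batch_size=256):
    """One toy step: sample cells at t0, integrate forward, compare to t1 via OT."""
    idx0 = torch.randperm(x_t0.shape[0])[:mini_batch_size]
    idx1 = torch.randperm(x_t1.shape[0])[:mini_batch_size]
    src, tgt = x_t0[idx0], x_t1[idx1]

    pred = odeint(func, src, torch.tensor([0.0, 1.0]))[-1]

    # Exact OT cost between predicted and true mini-batch distributions
    M = ot.dist(pred, tgt)                      # pairwise squared-Euclidean costs
    a = torch.full((mini_batch_size,), 1.0 / mini_batch_size)
    b = torch.full((mini_batch_size,), 1.0 / mini_batch_size)
    loss = ot.emd2(a, b, M)                     # differentiable w.r.t. M in recent POT

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```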
Dependencies:

```text
# Core ML libraries
torch
torchdiffeq
numpy
pandas
matplotlib
# Configuration management
hydra-core
omegaconf
# Single-cell analysis
scanpy
anndata
# Optimal transport
ot # POT library
geomloss # Optional, for GeomLoss
# Spatial analysis
scipy
sklearn
# Logging and visualization
wandb
tqdm
```

Installation:

```bash
pip install torch torchdiffeq hydra-core omegaconf numpy pandas matplotlib
pip install scanpy anndata scipy scikit-learn wandb tqdm
pip install pot # For optimal transport
pip install geomloss  # Optional, for additional OT methods
```

The code also requires the custom `dynamica` modules:
- `dynamica.sat.SpatialAttentionLayer`
- `dynamica.equi.E3NNVelocityPredictor`
Paths are currently hardcoded to:
- `/scratch/users/chensj16/codes/dynode_training`
- `/scratch/users/chensj16/codes/dynode_development`
Hardware requirements:

- CUDA-compatible GPU (configured for `cuda:0`)
- Sufficient memory for large datasets (the current data file is 1.5 GB)
Input data:

- File: `adata_list.v250715.pt` (PyTorch tensor file)
- Format: List of AnnData objects, one per timepoint
- Required keys (see the loading sketch after this list):
  - `adata.obsm['Z_mean']`: Gene expression embeddings (30D)
  - `adata.obsm['std_3D']`: Standardized 3D spatial coordinates
  - `adata.obs['CombinedCellType']`: Cell type annotations
  - `adata.obs['time']`: Timepoint information
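A minimal loading sketch for this file, using only the fields listed above; this is standard `torch`/`anndata` usage, not code from the repository:

```python
import torch

# The .pt file stores a pickled list of AnnData objects (one per timepoint),
# so it must be loaded with weights_only=False on newer PyTorch versions.
adata_list = torch.load("adata_list.v250715.pt", weights_only=False)

for adata in adata_list:
    Z = torch.as_tensor(adata.obsm["Z_mean"])    # (n_cells, 30) expression embeddings
    X = torch.as_tensor(adata.obsm["std_3D"])    # (n_cells, 3) standardized coordinates
    cell_types = adata.obs["CombinedCellType"]   # per-cell type annotations
    timepoint = adata.obs["time"].iloc[0]        # timepoint shared by cells in this object
```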
Usage:

```bash
python train.py
```

Configuration is managed through Hydra with YAML files. Default settings are in `config.yaml` and can be overridden via the command line.
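For orientation, a hypothetical `config.yaml` consistent with the override keys shown below might look like this (all field names are inferred from the command-line examples and may not match the actual file):

```yaml
device: cuda:0
debug_mode: false

training:
  lr: 1.0e-6
  n_epochs: 1000
  mini_batch_size: 512
  lr_schedule: {50: 1.0e-5, 100: 5.0e-6}

evaluation:
  enable_eval: true
  eval_samples: 1000
  eval_timepoints: [[1, 2], [3, 4]]

logging:
  use_wandb: true
```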
```bash
python train.py

# Modify learning rate
python train.py training.lr=1e-6
# Change epochs and batch size
python train.py training.n_epochs=1000 training.mini_batch_size=512
# Enable debug mode (1 epoch only)
python train.py debug_mode=true
# Switch device
python train.py device=cuda:1
```

```bash
# Modify evaluation timepoints
python train.py evaluation.eval_timepoints="[[1,2],[3,4]]"
# Change learning rate schedule
python train.py training.lr_schedule="{50:1e-5,100:5e-6}"
# Disable features
python train.py logging.use_wandb=false evaluation.enable_eval=false
```

```bash
# Use comprehensive training plan (16 forward/reverse sequences)
python train.py --config-name=config_full
# Use custom config file
python train.py --config-name=my_config
```

```bash
# Quick debug run
python train.py debug_mode=true evaluation.enable_eval=false
# Reduced evaluation overhead
python train.py evaluation.eval_samples=1000 training.mini_batch_size=512
# High learning rate experiment
python train.py training.lr=1e-5 training.lr_schedule="{10:5e-6,50:1e-6}"
# Full bidirectional training with comprehensive sequences
python train.py --config-name=config_full training.n_epochs=1000
```

Training features (a scheduling and checkpointing sketch follows this list):

- Learning Rate Scheduling: Automatic LR adjustments at specified epochs
- Gradient Clipping: Prevents gradient explosion (starts at epoch 100)
- Integrated Evaluation: Real-time assessment every N epochs
- WandB Logging: Experiment tracking and visualization
- Checkpointing: Model saves every 20 epochs
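A toy loop showing how a dict-style LR schedule, late-starting gradient clipping, and periodic checkpointing can fit together; the epoch thresholds come from this section, while the model, loss, and clip norm are placeholders:

```python
import torch
import torch.nn as nn

model = nn.Linear(33, 33)                        # stand-in for the ODE function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
lr_schedule = {50: 1e-5, 100: 5e-6}              # epoch -> new learning rate (example values)

for epoch in range(200):
    if epoch in lr_schedule:                     # scheduled LR adjustment
        for group in optimizer.param_groups:
            group["lr"] = lr_schedule[epoch]

    optimizer.zero_grad()
    loss = model(torch.randn(8, 33)).pow(2).mean()   # placeholder loss
    loss.backward()

    if epoch >= 100:                             # clipping starts at epoch 100, per this section
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    if epoch % 20 == 0:                          # checkpoint every 20 epochs
        torch.save(model.state_dict(), f"checkpoint_epoch{epoch:03d}.pt")
```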
Evaluation metrics:

- Coverage: Percentage of true cells within a learned radius of the predictions (a toy computation follows this list)
- Gromov-Wasserstein: Spatial structure preservation metric
- Label Consistency: Cell type preservation through trajectories
- Velocity Norms: Analysis of expression and position velocity magnitudes
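As a toy illustration of the coverage metric (the fraction of true cells that fall within some radius of a predicted cell), using `scipy`'s KD-tree; the radius here is an arbitrary placeholder, not the learned one:

```python
import numpy as np
from scipy.spatial import cKDTree

def coverage(pred: np.ndarray, true: np.ndarray, radius: float = 0.5) -> float:
    """Fraction of true cells with at least one predicted cell within `radius`."""
    tree = cKDTree(pred)
    dists, _ = tree.query(true, k=1)    # distance to nearest predicted cell per true cell
    return float(np.mean(dists <= radius))

# Example: coverage of 0.8 means 80% of true cells lie near some prediction
print(coverage(np.random.randn(500, 3), np.random.randn(500, 3)))
```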
Outputs:

- Model Checkpoints: Saved to `/scratch/users/chensj16/codes/dynode_training/mouse-data-tmp/`
- Logs: Written to the `./logs/` directory
- WandB Dashboard: Project `fgw-train250519`
Changelog:

- 2025-07-16: Initial commit with DynVelocity v15 training pipeline
TODO:

- [ ] Add loss support for `pot_partial_ot` and `pot_fugw`