This repository contains a workflow to benchmark the embedding quality of genomic foundation models on mRNA-specific tasks. The mRNABench contains a catalogue of datasets and training split logic which can be used to evaluate the embedding quality of several catalogued models.
Several configurations of the mRNABench are available.
If you are interested in the benchmark datasets only, you can run:

    pip install mrna-bench

The inference-capable version of mRNABench, which can generate embeddings using most models (except Evo2 and Helix mRNA), can be installed as shown below. Note that this requires PyTorch version 2.2.2 with CUDA 12.1.

    conda create --name mrna_bench python=3.10
    conda activate mrna_bench
    pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
    pip install mrna-bench[base_models]

Inference with other models will require the installation of the model's dependencies first, which are usually listed on the model's GitHub page (see below).
Inference using Evo2 requires installing the following in its own environment. Note: when the evo_40b models were downloaded, their merged checkpoints were stored one directory above the Hugging Face hub directory. Each checkpoint had to be moved manually into its corresponding snapshot directory: `/hub/models--arcinstitute-evo2_40b*/snapshots/snapshot_name/`

    conda create --name evo_bench python=3.11
    conda activate evo_bench
    conda install conda-forge::gcc # need updated gcc version
    cd path/to/mRNA/bench
    pip install -e .
    git clone --recurse-submodules git@github.com:ArcInstitute/evo2.git
    cd path/to/evo2
    pip install .
    pip install transformer_engine[pytorch]==1.13

Important: After installation, please run the following in Python to set where data associated with the benchmarks will be stored.

    import mrna_bench as mb

    # Directory where benchmark datasets will be stored.
    path_to_dir_to_store_data = "DESIRED_PATH"
    mb.update_data_path(path_to_dir_to_store_data)

    # Directory where downloaded model weights will be stored.
    path_to_dir_to_store_weights = "DESIRED_WEIGHTS_PATH"
    mb.update_model_weights_path(path_to_dir_to_store_weights)

Datasets can be retrieved using:

    import mrna_bench as mb

    dataset = mb.load_dataset("go-mf")
    data_df = dataset.data_df

The mRNABench can also be used to test out common genomic foundation models:

    import torch

    import mrna_bench as mb
    from mrna_bench.embedder import DatasetEmbedder
    from mrna_bench.linear_probe import LinearProbeBuilder

    device = torch.device("cuda")

    # Load the benchmark dataset and the model to evaluate.
    dataset = mb.load_dataset("go-mf")
    model = mb.load_model("Orthrus", "orthrus-large-6-track", device)

    # Embed every sequence in the dataset and convert the result to NumPy.
    embedder = DatasetEmbedder(model, dataset)
    embeddings = embedder.embed_dataset()
    embeddings = embeddings.detach().cpu().numpy()

    # Configure a linear probe on the embeddings.
    prober = (
        LinearProbeBuilder(dataset)
        .fetch_embedding_by_embedding_instance("orthrus-large-6", embeddings)
        .build_splitter("homology", species="human", eval_all_splits=False)
        .build_evaluator("multilabel")
        .set_target("target")
        .build()
    )

    metrics = prober.run_linear_probe(2541)
    print(metrics)

Also see the `scripts/` folder for example scripts that use Slurm to embed dataset chunks in parallel to reduce runtime, as well as an example of multi-seed linear probing.
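As a rough illustration of multi-seed linear probing, the sketch below runs a probe once per seed and summarises each metric as a mean and standard deviation. It rests on two assumptions that are not confirmed here: that the integer passed to `run_linear_probe` is a random seed, and that the call returns a dictionary mapping metric names to numbers. `probe_over_seeds` and `fake_probe` are hypothetical names used only for this example and are not part of `mrna_bench`.

```python
from statistics import mean, stdev
from typing import Callable, Dict, Iterable, List


def probe_over_seeds(
    run_probe: Callable[[int], Dict[str, float]],
    seeds: Iterable[int],
) -> Dict[str, Dict[str, float]]:
    """Run a probe once per seed and summarise each metric as mean and std."""
    per_metric: Dict[str, List[float]] = {}
    for seed in seeds:
        for name, value in run_probe(seed).items():
            per_metric.setdefault(name, []).append(float(value))
    return {
        name: {
            "mean": mean(values),
            # stdev needs at least two samples; report 0.0 for a single seed.
            "std": stdev(values) if len(values) > 1 else 0.0,
        }
        for name, values in per_metric.items()
    }


# Stand-in for prober.run_linear_probe so the sketch runs without mrna_bench.
# Assumption: the real method returns a dict of metric name -> number.
def fake_probe(seed: int) -> Dict[str, float]:
    return {"auroc": 0.80 + 0.01 * (seed % 3)}


summary = probe_over_seeds(fake_probe, seeds=[0, 1, 2])
print(summary)
```

If those assumptions hold, passing `prober.run_linear_probe` in place of `fake_probe` would give the same kind of summary for a real probe.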
The models supported by the base_models installation are catalogued below.
| Model Name | Model Versions | Description | Citation |
|---|---|---|---|
| Orthrus | `orthrus-large-6-track`, `orthrus-base-4-track` | Mamba-based RNA FM pre-trained using contrastive learning on ~45M RNA transcripts to capture functional and evolutionary relationships. | [Code] [Paper] |
| RNA-FM | `rna-fm`, `mrna-fm` | Transformer-based RNA FM pre-trained using MLM on 23M ncRNA sequences. mRNA-FM trained on mRNA CDS regions using codon tokenizer. | [GitHub] |
| DNABERT2 | `dnabert2` | Transformer-based DNA FM pre-trained using MLM on multispecies genomic dataset. Uses BPE and other modern architectural improvements for efficiency. | [GitHub] |
| NucleotideTransformer | `2.5b-multi-species`, `2.5b-1000g`, `500m-human-ref`, `500m-1000g`, `v2-50m-multi-species`, `v2-100m-multi-species`, `v2-250m-multi-species`, `v2-500m-multi-species` | Transformer-based DNA FM pre-trained using MLM on a variety of possible datasets at various model sizes. Sequence is tokenized using 6-mers. | [GitHub] |
| HyenaDNA | `hyenadna-large-1m-seqlen-hf`, `hyenadna-medium-450k-seqlen-hf`, `hyenadna-medium-160k-seqlen-hf`, `hyenadna-small-32k-seqlen-hf`, `hyenadna-tiny-16k-seqlen-d128-hf` | Hyena-based DNA FM pre-trained using NTP on the human reference genome. Available at various model sizes and pretraining sequence contexts. | [GitHub] |
| SpliceBERT | `SpliceBERT.1024nt`, `SpliceBERT-human.510nt`, `SpliceBERT.510nt` | Transformer-based RNA foundation model trained on 2M vertebrate mRNA sequences using MLM. Alternative versions trained on only human RNA, and with smaller context windows. | [GitHub] |
| RiNALMo | `rinalmo` | Transformer-based RNA foundation model trained on 36M ncRNA sequences using MLM and other modern architectural improvements such as RoPE, SwiGLU activations, and Flash Attention. | [GitHub] |
| UTR-LM | `utrlm-te_el`, `utrlm-mrl` | Transformer-based RNA foundation model that is pre-trained on random and endogenous 5'UTR sequences from various species using MLM. | [GitHub] |
| 3UTRBERT | `utrbert-3mer`, `utrbert-4mer`, `utrbert-5mer`, `utrbert-6mer` | Transformer-based RNA foundation model that is pre-trained on the 3'UTR regions of 100K RNA sequences using MLM. | [GitHub] |
| RNA-MSM | `rnamsm` | Transformer-based RNA foundation model trained by using MSA from custom structure-based homology map on roughly 8M RNA sequences. | [GitHub] |
| RNAErnie | `rnaernie` | Transformer-based RNA foundation model trained using MLM at various mask sizes on 23M ncRNA sequences. | [GitHub] |
| RNABERT | `rnabert` | Transformer-based RNA foundation model trained using MLM and a structural alignment objective on 80K ncRNA sequences. | [GitHub] |
| ERNIE-RNA | `ernierna`, `ernierna-ss` | Transformer-based RNA foundation model trained using MLM with structural information added as attention mask biases. Pretrained on 20M ncRNA sequences. | [GitHub] |
Many of the model wrappers (3UTRBERT, RiNALMo, UTR-LM, RNA-MSM, RNAErnie) use reimplementations from the multimolecule package. See their website for more details.
All models should inherit from the template `EmbeddingModel`. Each model file should lazily load its dependencies within its `__init__` method, so that each model can be used individually without installing the dependencies of all other models. Models must implement `get_model_short_name(model_version)`, which fetches the internal name for the model. This name must be unique for every model version and must not contain underscores. Models should implement either `embed_sequence` or `embed_sequence_sixtrack` (see code for method signature). New models should be added to `MODEL_CATALOG`.
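The requirements above can be sketched with a toy model. The `EmbeddingModel` class below is a stand-in written for this example only, since the real template's constructor and method signatures are not reproduced here and may differ. `KmerCountModel`, its version string, and the dictionary-style `MODEL_CATALOG` entry are likewise hypothetical. The sketch shows a lazy import inside `__init__`, a short name without underscores, and an `embed_sequence` implementation.

```python
from abc import ABC, abstractmethod
from typing import List


class EmbeddingModel(ABC):
    """Stand-in for the mRNABench template; the real signatures may differ."""

    def __init__(self, model_version: str, device: str) -> None:
        self.model_version = model_version
        self.device = device
        self.short_name = self.get_model_short_name(model_version)
        if "_" in self.short_name:
            raise ValueError("Model short names must not contain underscores.")

    @staticmethod
    @abstractmethod
    def get_model_short_name(model_version: str) -> str:
        ...

    @abstractmethod
    def embed_sequence(self, sequence: str) -> List[float]:
        ...


class KmerCountModel(EmbeddingModel):
    """Hypothetical toy model: embeds a sequence as normalised 2-mer counts."""

    def __init__(self, model_version: str, device: str = "cpu") -> None:
        super().__init__(model_version, device)
        # Lazy import: the dependency is only loaded when this model is built.
        import itertools

        self.kmers = ["".join(p) for p in itertools.product("ACGU", repeat=2)]

    @staticmethod
    def get_model_short_name(model_version: str) -> str:
        # Replace underscores so the name satisfies the naming rule.
        return model_version.replace("_", "-")

    def embed_sequence(self, sequence: str) -> List[float]:
        seq = sequence.upper().replace("T", "U")
        counts = {k: 0 for k in self.kmers}
        for i in range(len(seq) - 1):
            kmer = seq[i:i + 2]
            if kmer in counts:
                counts[kmer] += 1
        total = max(sum(counts.values()), 1)
        return [counts[k] / total for k in self.kmers]


# Hypothetical catalog entry; the real MODEL_CATALOG structure may differ.
MODEL_CATALOG = {"KmerCount": KmerCountModel}

model = MODEL_CATALOG["KmerCount"]("kmer_count_2")
print(model.short_name, len(model.embed_sequence("AUGGCC")))
```

Keeping the import inside `__init__` means that importing the module that defines the class does not require the model's dependencies to be installed.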
The current datasets catalogued are:
| Dataset Name | Catalogue Identifier | Description | Tasks | Citation |
|---|---|---|---|---|
| GO Molecular Function | `go-mf` | Classification of the molecular function of a transcript's product as defined by the GO Resource. | multilabel | website |
| Mean Ribosome Load (Sugimoto) | `mrl-sugimoto` | Mean ribosome load (MRL) per transcript isoform as measured in Sugimoto et al. 2022. | regression | paper |
| RNA Half-life (Human) | `rnahl-human` | RNA half-life of human transcripts collected by Agarwal et al. 2022. | regression | paper |
| RNA Half-life (Mouse) | `rnahl-mouse` | RNA half-life of mouse transcripts collected by Agarwal et al. 2022. | regression | paper |
| Protein Subcellular Localization | `prot-loc` | Subcellular localization of transcript protein product defined in Protein Atlas. | multilabel | website |
| Mean Ribosome Load (Sample) | `mrl-sample-egfp`, `mrl-sample-mcherry`, `mrl-sample-designed`, `mrl-sample-varying` | Mean ribosome load (MRL) measured in an MPRA of both random and designed 5'UTR regions (50nts) attached to a construct with either eGFP or mCherry. | regression | paper |
| Protein Coding Gene Essentiality | `pcg-ess` | Essentiality of PCGs as measured by CRISPR knockdown. Log-fold expression and binary essentiality available on several cell lines. | regression, classification | paper |
New datasets should inherit from `BenchmarkDataset`. Dataset names cannot contain underscores. Each new dataset should download raw data and process it into a dataframe by overriding `process_raw_data`. This dataframe should store transcripts as rows, with each sequence encoded as a string in the `sequence` column. If homology splitting is required, a `gene` column containing gene names is required. Six-track embedding also requires the columns `cds` and `splice`. The target column can have any name, as it is specified at probing time. New datasets should be added to `DATASET_CATALOG`.
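The dataframe requirements above can be turned into a small pre-flight check. The sketch below is not part of `mrna_bench`: `check_dataset` is a hypothetical helper that works on a list of column names (for example `list(data_df.columns)`), and the dataset name and columns in the usage lines are made up for illustration.

```python
from typing import Iterable, List

REQUIRED = ["sequence"]        # string-encoded transcript sequence
HOMOLOGY = ["gene"]            # needed for homology splitting
SIX_TRACK = ["cds", "splice"]  # needed for six-track embedding


def check_dataset(
    name: str,
    columns: Iterable[str],
    target: str,
    homology_split: bool = False,
    six_track: bool = False,
) -> List[str]:
    """Return a list of problems; an empty list means the checks passed."""
    problems = []
    if "_" in name:
        problems.append(f"dataset name '{name}' contains an underscore")
    needed = list(REQUIRED) + [target]
    if homology_split:
        needed += HOMOLOGY
    if six_track:
        needed += SIX_TRACK
    present = set(columns)
    for col in needed:
        if col not in present:
            problems.append(f"missing column '{col}'")
    return problems


# Hypothetical dataset: has the homology column but lacks the six-track columns.
problems = check_dataset(
    "my-dataset",
    ["sequence", "gene", "target"],
    target="target",
    homology_split=True,
    six_track=True,
)
print(problems)
```

Running a check like this at the end of `process_raw_data` would surface a missing column before any embedding or probing is attempted.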