This repository contains the code for "RAG-ESM", a method for training a retrieval-augmented generation (RAG) model based on the pre-trained ESM2 models. The model uses cross-attention layers to improve the performance of ESM2 models by conditioning the generation of masked protein sequences on embeddings of sequences that are homologous to the masked sequence. The model is trained on the OpenProteinSet dataset. Because it was also trained on the discrete diffusion task with a variable masking probability, it can be used to generate sequences from different levels of noise (masking) via a denoising process.
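The conditioning mechanism can be illustrated with a minimal, dependency-free sketch of single-head cross-attention. In this sketch, each position of the masked sequence attends over the positions of one homolog embedding. It is not the RAG-ESM implementation. The layers in esm_decoder.py presumably use learned projections and several heads, and they rely on Flash-Attention. Here, by contrast, the projections are assumed to be the identity, and the function names (cross_attention, add_homolog_information) are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows are sequence positions, columns are embedding dims


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of attention scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: Matrix, context: Matrix) -> Matrix:
    """Single-head scaled dot-product cross-attention.

    queries: hidden states of the masked query sequence, shape (L_q, d).
    context: embeddings of one homologous sequence, shape (L_k, d).
    The learned query/key/value projections are assumed to be the identity,
    so the context rows act as both keys and values.
    """
    d = len(queries[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for q in queries:
        # Similarity between this query position and every homolog position.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in context]
        weights = softmax(scores)
        # Weighted average of the homolog rows.
        out.append([sum(w * v[j] for w, v in zip(weights, context)) for j in range(d)])
    return out


def add_homolog_information(hidden: Matrix, context: Matrix) -> Matrix:
    """Residual update: each query position keeps its state and adds what it
    retrieved from the homolog, so the output has the query's shape."""
    attended = cross_attention(hidden, context)
    return [[h + a for h, a in zip(h_row, a_row)] for h_row, a_row in zip(hidden, attended)]


if __name__ == "__main__":
    hidden = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # 3 positions of the masked sequence
    homolog = [[1.0, 0.0], [0.0, 1.0]]             # 2 positions of a homolog embedding
    updated = add_homolog_information(hidden, homolog)
    print(len(updated), len(updated[0]))  # 3 2
```

The output keeps the length of the masked sequence whatever the length of the homolog. This is what lets a sequence be conditioned on homologs of different lengths.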
You will find the code for training the model in the src directory. The code is organized as follows:
- src/rag_esm: contains the code for the RAG-ESM model
  - configs: contains the configuration files used to train the model
    - train_8M.yaml: contains the hyperparameters used to train the RAG-ESM (12M) model
    - train.yaml: a more generic configuration file used to train any ESM model
    - setup.yaml: contains the base configuration for the hydra and wandb setup
  - modules: contains the architecture modules and the dataset class used to train the model
    - dataloaders.py: contains the dataset class used to train the model and the collate function used to create the batches
    - model.py: contains the architecture of the RAG-ESM model
    - esm_decoder.py: contains the modified ESM2 modules used in the RAG-ESM model. The main differences with respect to ESM2 are the addition of the cross-attention layers and the use of Flash-Attention.
  - utils: contains some useful functions used to train the model and parse the dataset (e.g. metrics.py, trainer.py and hamming.py), and a Jupyter notebook with the code used to create the training set
    - generate.py: contains useful functions to sample sequences from the model. In particular, denoise can be used to sample using the diffusion process.
  - train.py: contains the training script for the model
  - sample.py: contains the script to sample/generate sequences from the model
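The sampling idea behind denoise can be sketched as a masked discrete diffusion loop. First, a sequence is corrupted by masking each residue with probability p (the noise level). Then the masked positions are filled in over a fixed number of steps. This is a hypothetical, stdlib-only sketch, not the code in generate.py. The following are all assumptions:
- the mask symbol;
- the model interface (a callable returning one distribution over the 20 standard amino acids per position);
- the unmasking schedule (an equal share of the remaining masked positions per step, chosen at random);
- the toy copy_homolog_model standing in for RAG-ESM.

```python
import random
from typing import Callable, Dict, List

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
MASK = "#"  # hypothetical mask symbol used only in this sketch

# Hypothetical model interface: given a partially masked sequence, return one
# probability distribution over AMINO_ACIDS for every position.
Model = Callable[[List[str]], List[Dict[str, float]]]


def mask_sequence(seq: str, p: float, rng: random.Random) -> List[str]:
    """Forward (noising) process: mask each residue independently with probability p."""
    return [MASK if rng.random() < p else aa for aa in seq]


def denoise_sketch(tokens: List[str], model: Model, n_steps: int, rng: random.Random) -> str:
    """Reverse process: fill the masked positions over n_steps steps."""
    tokens = list(tokens)
    for step in range(n_steps):
        masked = [i for i, t in enumerate(tokens) if t == MASK]
        if not masked:
            break
        # Unmask an equal share of the remaining masked positions; the last
        # step unmasks everything that is left.
        k = max(1, len(masked) // (n_steps - step))
        probs = model(tokens)
        for i in rng.sample(masked, k):
            residues = list(probs[i])
            weights = [probs[i][aa] for aa in residues]
            tokens[i] = rng.choices(residues, weights=weights)[0]
    return "".join(tokens)


def copy_homolog_model(homolog: str) -> Model:
    """Toy stand-in for RAG-ESM: always predicts the homolog's residue."""
    def model(tokens: List[str]) -> List[Dict[str, float]]:
        return [
            {aa: (1.0 if aa == homolog[i] else 0.0) for aa in AMINO_ACIDS}
            for i in range(len(tokens))
        ]
    return model


if __name__ == "__main__":
    rng = random.Random(0)
    noisy = mask_sequence("MKTAYIAKQR", p=0.5, rng=rng)
    print("".join(noisy))
    print(denoise_sketch(noisy, copy_homolog_model("MKTAYIAKQR"), n_steps=3, rng=rng))  # MKTAYIAKQR
```

Some masked-diffusion samplers instead unmask the positions the model is most confident about first. The random order used here is only the simplest choice.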
To train your model you can use the following command:
python src/rag_esm/train.py
To modify the hyperparameters, edit the train.yaml file in the src/rag_esm/configs directory.
This project is licensed under the terms of the LICENSE file in the root directory of the project.
The initial code of this repository was generated from the Python Machine Learning Research Project Template, whose license is provided in the LICENSE.ml-template file.
Additional LICENSE files may be present in subdirectories of the project.