Multimodal generation of protein sequence and all-atom structure using latent diffusion, with compositional function and taxonomic prompts. http://bit.ly/plaid-proteins


PLAID (Protein Latent Induced Diffusion)

(Figure: results showing protein design capabilities.)

PLAID is a multimodal generative model that generates protein sequence and all-atom structure, conditioned on compositional function and taxonomic prompts. Please see our paper for more details.


Demo

A hosted demo of the model will be available soon.

Installation

Clone the Repository

git clone https://github.com/amyxlu/plaid.git
cd plaid

Environment Setup

Create the environment and install dependencies:

conda env create --file environment.yaml  # Create environment
pip install --no-deps git+https://github.com/amyxlu/openfold.git  # Install OpenFold
pip install -e .  # Install PLAID

Note: The OpenFold implementation of the ESMFold module includes custom CUDA kernels for the attention mechanism. This repository uses a fork of OpenFold with C++17 compatibility for CUDA kernels to support torch >= 2.0.

Model Weights

  • Latent Autoencoder (CHEAP): full codebase is available here. We use the CHEAP_pfam_shorten_2_dim_32() model.
  • Diffusion Weights (PLAID): hosted on HuggingFace. Both a 2B-parameter and a 100M-parameter model are available.

By default, PLAID weights are cached in ~/.cache/plaid and CHEAP latent autoencoder weights in ~/.cache/cheap. Customize the cache path using:

echo "export CHEAP_CACHE=/path/to/cache" >> ~/.bashrc  # see CHEAP README for more details
echo "export PLAID_CACHE=/path/to/cache" >> ~/.bashrc

Loading Pretrained Models

from plaid.pretrained import PLAID_2B, PLAID_100M
denoiser, cfg = PLAID_2B()

This loads the PLAID DiT denoiser and the hyperparameters used to initialize the diffusion object defined in src/plaid/diffusion/cfg.py. The denoiser and diffusion configuration are loaded separately, since in theory the denoiser can be used with any other diffusion setup, such as EDM. The sampling steps below initialize the discrete diffusion process used in our paper.

Basic Usage

The run_pipeline.py script offers an entry point to the full pipeline, which consists of:

  1. Sampling latent embeddings.
  2. Decoding these embeddings into sequences and structures.
  3. Folding and inverse folding acrobatics to report self-consistency and cross-consistency statistics.
  4. Computing analysis metrics, including Foldseek and MMseqs comparisons of generations against known protein sequences and/or structures.

Commands in this Basic Usage section will only run steps 1 & 2. To run the full pipeline including evaluations, see the Full Pipeline and Evaluation sections.

Important

The specified length is half the actual protein length and must be divisible by 4. For example, to generate a 200-residue protein, set length=100.
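The length rule above can be wrapped in a small helper (hypothetical, for illustration only; not part of the repo):

```python
def latent_length(num_residues: int) -> int:
    """Convert a target protein length in residues to the `sample.length`
    argument. The latent length is half the residue count and must be
    divisible by 4, so the residue count must be divisible by 8."""
    if num_residues % 8 != 0:
        raise ValueError("residue count must be divisible by 8")
    return num_residues // 2

print(latent_length(200))  # a 200-residue protein needs length=100
```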

Quick Start: Command line

Unconditional Sampling

SAMPLE_OUTPUT_DIR=/shared/amyxlu/plaid/artifacts/samples
python pipeline/run_pipeline.py experiment=generate_unconditional ++sample.output_root_dir=$SAMPLE_OUTPUT_DIR ++sample.length=60 ++sample.num_samples=16

Note that ++sample.output_root_dir has no default, and must be defined. Other defaults are defined in configs/inference/sample/ddim_unconditional.yaml.

This will save outputs to SAMPLE_OUTPUT_DIR/f2219_o3617_l60_s3/, where f2219 refers to the unconditional function index, o3617 refers to the unconditional organism index, and l60 refers to the latent length.

Conditional Sampling

In this example, we're generating proteins with 6-phosphofructokinase activity from E. coli.

SAMPLE_OUTPUT_DIR=/shared/amyxlu/plaid/artifacts/samples
python pipeline/run_pipeline.py experiment=generate_conditional ++sample.output_root_dir=$SAMPLE_OUTPUT_DIR ++sample.function_idx=166  ++sample.organism_idx=1030 ++sample.length=None ++sample.cond_scale=3.0

++sample.function_idx and ++sample.organism_idx are required. As in the unconditional case, ++sample.output_root_dir has no default and must be defined. The other defaults are this time specified in configs/sample_conditional.yaml. When ++sample.length=None, the length is chosen automatically based on the lengths of known Pfam domains with function index ++sample.function_idx; this auto-length feature only works when conditioning on a function. The conditioning scale of 3.0 determines how strongly to condition: a scale of 0.0 is equivalent to unconditional sampling.
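The conditioning scale enters through the standard classifier-free guidance combination. A minimal sketch of that formula (the actual implementation lives in src/plaid/diffusion/cfg.py and may differ in detail):

```python
def guided_output(uncond, cond, cond_scale):
    # Classifier-free guidance: extrapolate from the unconditional
    # prediction toward the conditional one. cond_scale=0.0 recovers
    # unconditional sampling; larger values condition more strongly.
    return [u + cond_scale * (c - u) for u, c in zip(uncond, cond)]

print(guided_output([0.0, 1.0], [1.0, 3.0], 3.0))  # [3.0, 7.0]
```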

This will save outputs to SAMPLE_OUTPUT_DIR/f166_o1030_l140_s3/, where f166 refers to the conditional function index, o1030 to the conditional organism index, and s3 to the classifier-free guidance scale. l140 is the auto-selected length, which may differ between runs.
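The output directory names follow a simple f{function}_o{organism}_l{length}_s{scale} pattern. A hypothetical helper that reconstructs it (the pipeline builds this name internally):

```python
def sample_dir_name(function_idx, organism_idx, length, cond_scale):
    # Mirrors the naming pattern seen in the sampling outputs,
    # e.g. f166_o1030_l140_s3. For illustration only.
    return f"f{function_idx}_o{organism_idx}_l{length}_s{int(cond_scale)}"

print(sample_dir_name(166, 1030, 140, 3.0))  # f166_o1030_l140_s3
```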

Tip

To find the mapping between your desired GO term and function index, see src/plaid/constants.py.

Quick Start: Notebook

You can also call the modular classes directly in a notebook, which affords some flexibility; for example, you can specify the GO term and organism directly as strings. See the conditional_demo.ipynb notebook for an example.

Full Pipeline

The pipeline/run_pipeline.py script runs the full pipeline, including sampling, decoding, consistency, and analysis. See configs/inference/full.yaml for the full pipeline config.

You can also run each of these steps as individual scripts, if you need to resume from a pipeline step after an error. Scripts for each step are located in pipeline. These scripts are wrappers for the logic defined in src/plaid/pipeline.

Step 1: Sampling Latent Embeddings

Run latent sampling using the Hydra-configured scripts in configs/pipeline/sample/. Example commands:
SAMPLE_OUTPUT_DIR=/shared/amyxlu/plaid/artifacts/samples

# Conditional sampling with inferred length
python pipeline/run_sample.py ++length=null ++function_idx=166 ++organism_idx=1326 ++sample.output_root_dir=$SAMPLE_OUTPUT_DIR

# Conditional sampling with fixed length
python pipeline/run_sample.py ++length=200 ++function_idx=166 ++organism_idx=1326 ++sample.output_root_dir=$SAMPLE_OUTPUT_DIR

# Unconditional sampling with specified output directory
python pipeline/run_sample.py ++length=200 ++function_idx=2219 ++organism_idx=3617 ++sample.output_root_dir=$SAMPLE_OUTPUT_DIR

Tip

PLAID also supports the DPM++ sampler, which achieves comparable performance with fewer sampling steps. See configs/inference/sample/dpm2m_sde.yaml for more details.

Step 2: Decode the Latent Embedding

  • 2a. Uncompress latent arrays using the CHEAP autoencoder.
  • 2b. Use the CHEAP sequence decoder to obtain sequences.
  • 2c. Use the ESMFold structure decoder to obtain structures.
python pipeline/run_pipeline.py experiment=generate_unconditional ++npz_path=$SAMPLE_OUTPUT_DIR/f2219_o3617_l60_s3/latent.npz

Note that the value of ++npz_path depends on which settings were used in Step 1, but the path always ends in latent.npz.
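A sampled latent.npz can be inspected with standard NumPy tools before decoding. The array key and shapes below are illustrative stand-ins, not the actual names the pipeline stores:

```python
import numpy as np

# Stand-in for a sampled latent file: 16 samples of latent length 60 with
# latent dim 32 (matching CHEAP_pfam_shorten_2_dim_32). Key name is made up.
np.savez("latent.npz", samples=np.zeros((16, 60, 32), dtype=np.float32))

# Inspect whatever arrays a latent.npz actually contains.
with np.load("latent.npz") as data:
    for key in data.files:
        print(key, data[key].shape, data[key].dtype)
```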

Step 3: Generate inverse and phantom sequences/structures

python pipeline/run_consistency.py ++samples_dir=/path/to/samples

Step 4: Analyze metrics (ccRMSD, novelty, diversity, etc.)

python pipeline/run_analysis.py /path/to/samples
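The consistency metrics are computed by run_analysis.py. For reference, the RMSD core of such metrics is optimal rigid-body superposition followed by root-mean-square deviation; a minimal NumPy sketch of the Kabsch algorithm (not the repo's implementation):

```python
import numpy as np

def kabsch_rmsd(P, Q):
    """RMSD between two (N, 3) coordinate sets after optimal rigid-body
    superposition via the Kabsch algorithm."""
    P = P - P.mean(axis=0)            # center both coordinate sets
    Q = Q - Q.mean(axis=0)
    H = P.T @ Q                       # cross-covariance matrix
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))   # guard against reflections
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T  # optimal rotation P -> Q
    return float(np.sqrt(((P @ R.T - Q) ** 2).sum() / len(P)))

# A rotated-and-translated copy should superimpose exactly (RMSD ~ 0).
rng = np.random.default_rng(0)
P = rng.normal(size=(60, 3))
theta = 0.5
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
Q = P @ Rz.T + np.array([1.0, 2.0, 3.0])
print(round(kabsch_rmsd(P, Q), 6))  # 0.0
```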

Training

Train PLAID models using PyTorch Lightning with distributed data parallel (DDP). Example launch command for training on 8 A100 GPUs:

python train_compositional.py  # see config/experiments

Key features:

  • Min-SNR loss scaling
  • Classifier-free guidance (GO terms and organisms)
  • Self-conditioning
  • EMA weight decay

Note: If using torch.compile, ensure precision is set to float32 due to compatibility issues with the xFormers library.

Embeddings are pre-computed and cached as .tar files for compatibility with WebDataset dataloaders. Pfam embedding .tar files used for training and validation data will be uploaded soon.

License

PLAID is licensed under the MIT License. See the LICENSE file for details.
