PersonalBPNet

A small modification to bpnetlite's BPNet to accommodate large validation datasets.

We redid the validation loop to work with a PyTorch DataLoader (e.g., one generated by GenVarLoader), rather than having to load the whole validation set into memory at once. The model checkpoints also save the optimizer state dict, epoch number, and number of steps since last improvement in addition to the model state dict, so that training can be resumed from a checkpoint with the correct optimizer and early stopping/epoch states.
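The checkpoint-resume idea above can be sketched as follows. This is a minimal illustration, not the package's actual checkpoint schema; the key names (`model_state_dict`, `optimizer_state_dict`, `epoch`, `steps_since_improvement`) and the toy model are assumptions for demonstration only.

```python
import torch
from torch import nn, optim

# Toy stand-ins for the real model and optimizer.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Save weights plus the optimizer/early-stopping state needed to resume.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": 5,
        "steps_since_improvement": 2,
    },
    "checkpoint.torch",
)

# Later: restore the model, the optimizer, and the bookkeeping, then
# continue training from the next epoch.
ckpt = torch.load("checkpoint.torch")
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
```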

Additionally, we include a PyTorch implementation of CLIPNET, which is essentially BPNet with added batch norm layers, similar to what was done in the original TensorFlow implementation of CLIPNET.

Installation

Install directly from the GitHub repo:

pip install git+https://github.com/adamyhe/personalbpnet.git

PyTorch API

The PersonalBPNet and CLIPNET classes can be directly imported:

from personal_bpnet.personal_bpnet import PersonalBPNet
from personal_bpnet.clipnet_pytorch import CLIPNET

The PersonalBPNet class is identical to the BPNet class from bpnetlite, but its fit method has been modified to accept a PyTorch DataLoader for validation data, rather than fixed tensors. This significantly reduces the memory footprint of the validation step. Note that random data augmentations (jittering, reverse complement) should be turned off so that the validation dataset stays fixed across epochs.
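A minimal sketch of wrapping fixed validation tensors in a DataLoader, as the modified fit method expects. The array shapes (2114-bp one-hot inputs, 1000-bp two-strand profiles) and batch size are illustrative assumptions; no shuffling or augmentation is applied, so every epoch sees the same validation set.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical shapes: 100 validation windows of 2114 bp (4 channels),
# each paired with a 1000-bp output profile over 2 strands.
X_val = torch.zeros(100, 4, 2114)
y_val = torch.zeros(100, 2, 1000)

# shuffle=False and no augmentation transforms: the validation set is
# identical on every pass, so early stopping sees a stable metric.
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=16, shuffle=False)

# model.fit(...) would then consume val_loader batch by batch instead of
# holding the full validation set in memory (exact fit signature per the
# package itself).
```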

The CLIPNET class modifies PersonalBPNet to include batch normalization layers after each convolutional and linear layer, which we've found to improve prediction accuracy.

We've deposited pre-trained CLIPNET PyTorch weights on Zenodo. These were trained using the same multi-individual LCL PRO-cap dataset as the original CLIPNET models (training data also available on Zenodo). We've only saved the model weights, so you'll need to initialize the models, then use load_state_dict:

import os
import torch

from personal_bpnet.clipnet_pytorch import CLIPNET

os.system("wget https://zenodo.org/records/14632152/files/lcl_procap_models.tar --quiet")
os.system("tar -xvf lcl_procap_models.tar")
model = CLIPNET(
    n_filters=512, n_outputs=2, n_control_tracks=0, n_layers=8, trimming=(2114 - 1000) // 2
)
model.load_state_dict(torch.load("lcl_procap_models/f1.torch"))  # one of 9 model replicates

IMPORTANT: The pretrained CLIPNET PyTorch models were trained on half two-hot encoded sequences. That is, homozygous positions are represented with standard one-hot encodings of the 4 nucleotides, while heterozygous positions are represented as [0.5, 0.5, 0, 0], [0.5, 0, 0.5, 0], ....
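The encoding described above can be sketched as follows. This is an illustrative helper, not the package's own encoder: each haplotype allele contributes 0.5 to its nucleotide channel, so homozygous positions reduce to ordinary one-hot columns and heterozygous positions split 0.5/0.5 between two channels.

```python
import numpy as np

ALPHABET = "ACGT"

def two_hot(hap1: str, hap2: str) -> np.ndarray:
    """Encode a pair of haplotype sequences as a (4, L) half two-hot array."""
    out = np.zeros((4, len(hap1)), dtype=np.float32)
    for i, (a, b) in enumerate(zip(hap1, hap2)):
        out[ALPHABET.index(a), i] += 0.5  # allele from haplotype 1
        out[ALPHABET.index(b), i] += 0.5  # allele from haplotype 2
    return out

# Position 0 is homozygous A -> [1, 0, 0, 0];
# position 1 is a C/T heterozygote -> 0.5 in the C and T channels.
enc = two_hot("ACG", "ATG")
```

Each column sums to 1 regardless of genotype, which keeps the input scale consistent with one-hot encoded sequence.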

Porting TensorFlow weights to PyTorch

We also provide a method to load the original, TensorFlow-trained models into PyTorch. The CLIPNET_TF.from_tf class method can be used to load hdf5 weights into a PyTorch module.

import os
from personal_bpnet.clipnet_tensorflow import CLIPNET_TF

os.makedirs("clipnet_models/", exist_ok=True)
for i in range(1, 10):
    os.system(f"wget https://zenodo.org/records/10408623/files/fold_{i}.h5 -P clipnet_models/")

models = [
    CLIPNET_TF.from_tf(f"clipnet_models/fold_{i}.h5") for i in range(1, 10)
]

These models have been converted to expect inputs of shape (N, 4, 1000). IMPORTANT: Models loaded in this fashion still expect inputs to be two-hot encoded (see the description in the README of our TensorFlow package). For compatibility with packages that only allow one-hot encoded sequences, please use the TwoHotToOneHot wrapper:

from personal_bpnet.clipnet_tensorflow import TwoHotToOneHot

ohe_models = [TwoHotToOneHot(m) for m in models]

This class works for all models trained using the rnn_v10 architecture in the original CLIPNET repo. At present, this includes

PauseNet

We also provide a PauseNet class (no published models yet, as we have not gotten this to work well). It is a wrapper around bpnetlite.bpnet.BPNet, PersonalBPNet, or CLIPNET models that transforms them to predict a single scalar output per input sequence. This is designed for fine-tuning the base-resolution models to predict regulatory phenotypes that can only be represented as a single scalar value per region (e.g., pausing index, for which this class is named). The intended use for this class is as follows:

import torch

from personal_bpnet import CLIPNET, PauseNet

# This is for loading from a weights dictionary.
# If you saved the full model, just directly use pretrain=torch.load("weights.torch")
pretrain = CLIPNET(**init_args)
pretrain.load_state_dict(torch.load("weights.torch"))

model = PauseNet(pretrain)
model.fit(**params)

This package is currently in active development and may change drastically. Models have not been extensively benchmarked yet, and there may be typos or copy-paste errors. A personalized ChromBPNet fitting method has not been included, as we have not had success training such models.

Command line interface

For convenience, prediction and attribution (DeepLIFT/SHAP) methods for CLIPNET or PauseNet models can be accessed via a CLI:

clipnet predict -h
clipnet predict_tss -h
clipnet attribute -h

pausenet predict -h
pausenet attribute -h
