A small modification to bpnetlite's `BPNet` to accommodate large validation datasets.
The validation loop has been reworked to consume a PyTorch DataLoader (e.g., one generated by GenVarLoader), rather than requiring the whole validation set to be loaded into memory at once. In addition, model checkpoints now save the optimizer state dict, the epoch number, and the number of steps since the last improvement alongside the model state dict, so that training can be resumed from a checkpoint with the correct optimizer and early stopping/epoch states.
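Resuming from such a checkpoint might look like the following sketch. The checkpoint key names here are hypothetical, for illustration only; check the dictionary that the fit method actually saves:

```python
import torch

from personal_bpnet.personal_bpnet import PersonalBPNet

# Rebuild the model and optimizer exactly as they were configured for training.
model = PersonalBPNet(n_filters=512, n_outputs=2, n_control_tracks=0, n_layers=8)
optimizer = torch.optim.Adam(model.parameters())

# Hypothetical key names, for illustration only; inspect the checkpoint
# saved during training for the actual keys.
checkpoint = torch.load("checkpoint.torch")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]
steps_since_improvement = checkpoint["steps_since_improvement"]
```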
Additionally, we include a PyTorch implementation of CLIPNET, which is essentially BPNet with added batch norm layers, mirroring the original TensorFlow implementation of CLIPNET.
Install from the GitHub repo:
```bash
pip install git+https://github.com/adamyhe/personalbpnet.git
```

The `PersonalBPNet` and `CLIPNET` classes can be directly imported:
```python
from personal_bpnet.personal_bpnet import PersonalBPNet
from personal_bpnet.clipnet_pytorch import CLIPNET
```

The `PersonalBPNet` class is identical to the `BPNet` class from bpnetlite, except that its `fit` method has been modified to accept a PyTorch DataLoader for validation data rather than fixed tensors, which substantially reduces the memory footprint of the validation step. Note that random data augmentations (jittering, reverse complementing) should be turned off so that the validation dataset stays fixed across epochs.
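For example, validation data can be supplied as an ordinary DataLoader over fixed tensors. A minimal sketch with random stand-in data; the exact `fit` argument names are illustrative, so consult the signature in personal_bpnet:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from personal_bpnet.personal_bpnet import PersonalBPNet

# Random stand-in data: sequences of shape (N, 4, 2114) and
# profile tracks of shape (N, 2, 1000).
X_train, y_train = torch.rand(1024, 4, 2114), torch.rand(1024, 2, 1000)
X_valid, y_valid = torch.rand(128, 4, 2114), torch.rand(128, 2, 1000)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
# No shuffling or augmentation, so the validation set is identical every epoch.
valid_loader = DataLoader(TensorDataset(X_valid, y_valid), batch_size=32, shuffle=False)

model = PersonalBPNet(
    n_filters=512, n_outputs=2, n_control_tracks=0, n_layers=8, trimming=(2114 - 1000) // 2
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Argument names are illustrative; check the modified fit signature.
model.fit(train_loader, optimizer, valid_loader, max_epochs=50)
```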
The `CLIPNET` class modifies `PersonalBPNet` to include batch normalization layers after each convolutional and linear layer, which we've found to improve prediction accuracy.
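For illustration, the pattern described above looks like the following sketch (not the actual module code):

```python
import torch.nn as nn

# Illustrative only: a convolution followed by batch normalization, the
# pattern CLIPNET inserts after each convolutional (and linear) layer.
block = nn.Sequential(
    nn.Conv1d(4, 512, kernel_size=21, padding=10),
    nn.BatchNorm1d(512),
    nn.ReLU(),
)
```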
We've deposited pre-trained CLIPNET PyTorch weights on Zenodo. These were trained on the same multi-individual LCL PRO-cap dataset as the original CLIPNET models (training data also available on Zenodo). Only the model weights were saved, so you'll need to initialize the models and then use `load_state_dict`:
```python
import os

import torch

from personal_bpnet.clipnet_pytorch import CLIPNET

# Download and unpack the pretrained weights.
os.system("wget https://zenodo.org/records/14632152/files/lcl_procap_models.tar --quiet")
os.system("tar -xvf lcl_procap_models.tar")

model = CLIPNET(
    n_filters=512, n_outputs=2, n_control_tracks=0, n_layers=8, trimming=(2114 - 1000) // 2
)
model.load_state_dict(torch.load("lcl_procap_models/f1.torch"))  # One of 9 model replicates.
```

IMPORTANT: The pretrained CLIPNET PyTorch models were trained on half two-hot encoded sequences. That is, homozygous positions are represented with one-hot encodings of the 4 nucleotides, and heterozygous positions are represented as [0.5, 0.5, 0, 0], [0.5, 0, 0.5, 0], etc.
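As a minimal sketch of this encoding (the helper and the A/C/G/T channel order are assumptions for the example, not part of the package):

```python
import torch

# Hypothetical helper: encode one diploid position by adding 0.5 for
# each allele (channel order A, C, G, T assumed for illustration).
def encode_position(allele1: int, allele2: int) -> torch.Tensor:
    enc = torch.zeros(4)
    enc[allele1] += 0.5
    enc[allele2] += 0.5
    return enc

print(encode_position(0, 0))  # homozygous A    -> tensor([1.0, 0.0, 0.0, 0.0])
print(encode_position(0, 1))  # heterozygous A/C -> tensor([0.5, 0.5, 0.0, 0.0])
```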
We also provide a way to load the original TensorFlow-trained models into PyTorch: the `CLIPNET_TF.from_tf` class method loads HDF5 weights into a PyTorch module.
```python
import os

from personal_bpnet.clipnet_tensorflow import CLIPNET_TF

# Download the original TensorFlow model weights (9 folds).
os.makedirs("clipnet_models/", exist_ok=True)
for i in range(1, 10):
    os.system(f"wget https://zenodo.org/records/10408623/files/fold_{i}.h5 -P clipnet_models/")

models = [CLIPNET_TF.from_tf(f"clipnet_models/fold_{i}.h5") for i in range(1, 10)]
```

These converted models expect inputs of shape (N, 4, 1000). IMPORTANT: Models loaded in this fashion still expect inputs to be two-hot encoded (see the description in the README of our TensorFlow package). For compatibility with packages that only accept one-hot encoded sequences, please use the `TwoHotToOneHot` wrapper:
```python
from personal_bpnet.clipnet_tensorflow import TwoHotToOneHot

ohe_models = [TwoHotToOneHot(m) for m in models]
```

This class works for all models trained using the `rnn_v10` architecture in the original CLIPNET repo (see the usage sketch after the list below). At present, this includes:
- the original LCL PRO-cap models,
- K562 PRO-cap models (fine-tuned from the above),
- ablated LCL PRO-cap models.
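As a usage sketch, the wrapped models can then be fed ordinary one-hot sequences. We assume here that each model returns a (profile, counts) pair, following bpnetlite's `BPNet` convention; check `CLIPNET_TF.forward` for the actual output format:

```python
import torch

# Random one-hot sequences of shape (N, 4, 1000), the expected input shape.
idx = torch.randint(4, (8, 1000))
X = torch.nn.functional.one_hot(idx, num_classes=4).permute(0, 2, 1).float()

# Average predictions across the 9 replicates. The (profile, counts) output
# format is an assumption borrowed from bpnetlite's BPNet.
with torch.no_grad():
    profiles, counts = zip(*[m(X) for m in ohe_models])

mean_profile = torch.stack(profiles).mean(dim=0)
mean_counts = torch.stack(counts).mean(dim=0)
```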
We also provide a `PauseNet` class (no published models yet, as we haven't gotten this to work well). It is a wrapper around `bpnetlite.bpnet.BPNet`, `PersonalBPNet`, or `CLIPNET` models that transforms them to predict a single scalar output per input sequence, for fine-tuning the base-resolution models on regulatory phenotypes that can only be represented as a single scalar per region (e.g., pausing index, for which the class is named). The intended use is as follows:
```python
import torch

from personal_bpnet import CLIPNET, PauseNet

# This is for loading from a weights dictionary.
# If you saved the full model, just use pretrain = torch.load("weights.torch") directly.
pretrain = CLIPNET(**init_args)
pretrain.load_state_dict(torch.load("weights.torch"))

model = PauseNet(pretrain)
model.fit(**params)
```

This package is currently in active development and may change drastically. Models have not been extensively benchmarked yet, and there may be typos/copy-paste errors. A personalized ChromBPNet fitting method has not been included, as we have not had success training such models.
For convenience, prediction and attribution (DeepLIFT/SHAP) methods for CLIPNET or PauseNet models can be accessed via a CLI:
```bash
clipnet predict -h
clipnet predict_tss -h
clipnet attribute -h
pausenet predict -h
pausenet attribute -h
```