
Commit

initial commit
vvyaltsev authored Mar 11, 2024
1 parent a0a66c1 commit 0ba3e90
Showing 8 changed files with 1,448 additions and 0 deletions.
66 changes: 66 additions & 0 deletions config.py
@@ -0,0 +1,66 @@
import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class ArmNetConfig:
# DATA PATHS
train_data_path: Optional[str] = None
test_data_path: Optional[str] = None
train_bppm_data_path: Optional[str] = None
test_bppm_data_path: Optional[str] = None
pretrained_model_weights: Optional[str] = None

# MODEL PARAMETERS
hidden_dim: int = 192
head_size: int = 32
num_encoder_layers: int = 12
num_conv_layers: Optional[int] = None
conv_1d_kernel_size: int = 17
conv_2d_kernel_size: int = 3
dropout: float = 0.1
conv_1d_use_dropout: bool = False
use_bppm: bool = False

    # TRAINING PARAMETERS
no_weights: bool = False
num_folds: int = 4
fold: int = 0
lr_max: float = 2.5e-3
weight_decay: float = 0.05
pct_start: float = 0.05
gradclip: float = 1.0
num_epochs: int = 200
num_workers: int = 32
batch_size: int = 128
batch_count: int = 1791
device: int = 0
seed: int = 2023

sgd_lr: float = 5e-5
sgd_num_epochs: int = 25
sgd_batch_count: int = 500
sgd_weight_decay: float = 0.05


def save(self, file_path):
with open(file_path, "w") as file:
json.dump(self.__dict__, file, indent=4)

def load(self, config_path):
with open(config_path, "r") as file:
config = json.load(file)
for key, value in config.items():
if key in self.__dict__:
setattr(self, key, value)

def load_dict(self, param_dict):
for key, value in param_dict.items():
if key in self.__dict__ and value is not None:
setattr(self, key, value)
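
A minimal usage sketch (hypothetical script, assuming the module is importable as config per the file name above): construct the dataclass with its defaults, override a path, and write it out as JSON.

from config import ArmNetConfig

cfg = ArmNetConfig()                                   # every field starts at its default
cfg.train_data_path = "/path/to/train_data.parquet"    # placeholder path, as in config_example.json
cfg.save("run_config.json")                            # dumps the dataclass fields as indented JSON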


32 changes: 32 additions & 0 deletions config_example.json
@@ -0,0 +1,32 @@
{
"train_data_path": "/path/to/train_data.parquet",
"test_data_path": "/path/to/test_sequences.csv",
"train_bppm_data_path": "/path/to/bppm/data/train",
"test_bppm_data_path": "/path/to/bppm/data/test",
"pretrained_model_weights": null,
"hidden_dim": 128,
"head_size": 32,
"num_encoder_layers": 4,
"num_conv_layers": null,
"conv_1d_kernel_size": 17,
"conv_2d_kernel_size": 3,
"dropout": 0.1,
"conv_1d_use_dropout": false,
"use_bppm": true,
"num_folds": 4,
"fold": 0,
"lr_max": 0.0025,
"weight_decay": 0.05,
"pct_start": 0.05,
"gradclip": 1.0,
"num_epochs": 100,
"num_workers": 32,
"batch_size": 128,
"batch_count": 1791,
"device": 0,
"seed": 2023,
"sgd_lr": 5e-05,
"sgd_num_epochs": 25,
"sgd_batch_count": 500,
"sgd_weight_decay": 0.05
}
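
Loading this example file back into the dataclass is then a one-liner; the override dict below is hypothetical, standing in for parsed command-line arguments:

from config import ArmNetConfig

cfg = ArmNetConfig()
cfg.load("config_example.json")                  # only keys that exist on the dataclass are applied
cfg.load_dict({"batch_size": 64, "fold": None})  # None values are skipped, so fold keeps its loaded value
print(cfg.hidden_dim, cfg.use_bppm)              # 128 True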
260 changes: 260 additions & 0 deletions dataset.py
@@ -0,0 +1,260 @@
from pathlib import Path
from typing import Optional, ClassVar

import torch
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold


def _load_bppm(
seq_id: str,
Lmax: int,
bppm_path: Path,
):
path = bppm_path / f"{seq_id}.npy"
mat = np.load(path)
dif = Lmax - mat.shape[0]
res = np.pad(mat, ((0, dif), (0, dif)))
return res


class BPPFeatures:
LMAX: ClassVar[int] = 206

def __init__(self, index_path: str, mempath: str):
self.index = self.read_index(index_path)
self.storage = self.read_memmap(mempath, len(self.index))

@classmethod
def read_index(cls, index_path):
with open(index_path) as inp:
ids = [line.strip() for line in inp]
index = {seqid: i for i, seqid in enumerate(ids)}

return index

@classmethod
def read_memmap(cls, memmap_path, index_len):
storage = np.memmap(memmap_path,
dtype=np.float32,
mode='r',
offset=0,
shape=(index_len, cls.LMAX, cls.LMAX),
order='C')
return storage

def __getitem__(self, seqid):
ind = self.index[seqid]
return self.storage[ind]


class RNA_Dataset(Dataset):
def __init__(self,
df,
mode: str='train',
seed: int = 2023,
fold: int = 0,
nfolds: int = 4,
use_bppm: bool = False,
bppm_path: Optional[Path] = None
):
self.seq_map = {'A':0,
'C':1,
'G':2,
'U':3,
"START": 4,
"END": 5,
"EMPTY": 6}

        assert mode in ('train', 'eval')
        df['L'] = df.sequence.apply(len)
        self.Lmax = df['L'].max()

df_2A3 = df.loc[df.experiment_type=='2A3_MaP']
df_DMS = df.loc[df.experiment_type=='DMS_MaP']
split = list(KFold(n_splits=nfolds, random_state=seed,
shuffle=True).split(df_2A3))[fold][0 if mode=='train' else 1]
df_2A3 = df_2A3.iloc[split].reset_index(drop=True)
df_DMS = df_DMS.iloc[split].reset_index(drop=True)

if mode == "eval":
print("Keeping only clean data for validation")
m = (df_2A3['SN_filter'].values > 0) & (df_DMS['SN_filter'].values > 0)
df_2A3 = df_2A3.loc[m].reset_index(drop=True)
df_DMS = df_DMS.loc[m].reset_index(drop=True)

self.sid = df_2A3['sequence_id'].values
self.seq = df_2A3['sequence'].values
self.L = df_2A3['L'].values

self.react_2A3 = df_2A3[[c for c in df_2A3.columns if 'reactivity_0' in c]].values
self.react_DMS = df_DMS[[c for c in df_DMS.columns if 'reactivity_0' in c]].values
self.react_err_2A3 = df_2A3[[c for c in df_2A3.columns if 'reactivity_error_0' in c]].values
self.react_err_DMS = df_DMS[[c for c in df_DMS.columns if 'reactivity_error_0' in c]].values

self.is_good = ((df_2A3['SN_filter'].values > 0) & (df_DMS['SN_filter'].values > 0) )* 1
self.sn_2A3 = df_2A3['SN_filter'].values
self.sn_DMS = df_DMS['SN_filter'].values

sn = (df_2A3['signal_to_noise'].values + df_DMS['signal_to_noise'].values) / 2

sn = torch.from_numpy(sn)
self.weights = 0.5 * torch.clamp_min(torch.log(sn + 1.01),0.01)

self.mode = mode
self._use_bppm = use_bppm
if use_bppm:
if bppm_path is None:
raise ValueError("If use_bppm is set True, bppm_path must be specified.")
self.bppm_features = BPPFeatures(bppm_path / "index.ind", bppm_path / "joined.mmap")

def __len__(self):
return len(self.seq)

def _process_seq(self, rawseq):
seq = [self.seq_map['START']]
start_loc = 0
seq.extend(self.seq_map[s] for s in rawseq)
seq.append(self.seq_map['END'])
end_loc = len(seq) - 1
for i in range(len(seq), self.Lmax+2):
seq.append(self.seq_map['EMPTY'])

seq = np.array(seq)
seq = torch.from_numpy(seq)

return seq, start_loc, end_loc

def __getitem__(self, idx):
seq = self.seq[idx]
real_seq_L = len(seq)

lbord = 1
rbord = self.Lmax + 1 - real_seq_L

seq_int, start_loc, end_loc = self._process_seq(seq)
mask = torch.zeros(self.Lmax + 2, dtype=torch.bool)
mask[start_loc+1:end_loc] = True # not including START and END
conv_bpp_mask = torch.zeros(self.Lmax + 2, self.Lmax + 2, dtype=torch.bool)
conv_bpp_mask[start_loc+1:end_loc, start_loc+1:end_loc] = True # not including START and END

forward_mask = torch.zeros(self.Lmax + 2, dtype=torch.bool) # START, seq, END
forward_mask[start_loc:end_loc+1] = True # including START and END


react = np.stack([self.react_2A3[idx][:real_seq_L],
self.react_DMS[idx][:real_seq_L]],
-1)
react = np.pad(react, ((lbord,
rbord),
(0,0)), constant_values=np.nan)

react = torch.from_numpy(react)


X = {'seq_int': seq_int,
'mask': mask,
'forward_mask': forward_mask,
'conv_bpp_mask': conv_bpp_mask,
'is_good': self.is_good[idx]}

sid = self.sid[idx]

if self._use_bppm:
adj = self.bppm_features[sid][:real_seq_L, :real_seq_L]
adj = np.pad(adj, ((lbord,rbord), (lbord, rbord)), constant_values=0)
adj = torch.from_numpy(adj).float()
X['adj'] = adj


y = {'react': react.float(),
'mask': mask}


return X, y


class RNA_Dataset_Test(Dataset):
def __init__(self,
df: pd.DataFrame,
use_bppm: bool = False,
bppm_path: Optional[Path] = None
):
self.df = df
self.seq_map = {'A':0,
'C':1,
'G':2,
'U':3,
"START": 4,
"END": 5,
"EMPTY": 6}
df['L'] = df.sequence.apply(len)
self.Lmax = df['L'].max()
self.sid = df.sequence_id
self._use_bppm = use_bppm
self._bppm_path = bppm_path
if use_bppm and bppm_path is None:
raise ValueError("If use_bppm is set True, bppm_path must be specified.")

def __len__(self):
return len(self.df)

def _process_seq(self, rawseq):
seq = [self.seq_map['START']]
start_loc = 0
seq.extend(self.seq_map[s] for s in rawseq)
seq.append(self.seq_map['END'])
end_loc = len(seq) - 1
for i in range(len(seq), self.Lmax+2):
seq.append(self.seq_map['EMPTY'])

seq = np.array(seq)
seq = torch.from_numpy(seq)

return seq, start_loc, end_loc

def __getitem__(self, idx):
id_min, id_max, seq = self.df.loc[idx, ['id_min','id_max','sequence']]
L = len(seq)

ids = np.arange(id_min,id_max+1)
ids = np.pad(ids,(1,self.Lmax+1-L), constant_values=-1)


seq_int, start_loc, end_loc = self._process_seq(seq)
mask = torch.zeros(self.Lmax + 2, dtype=torch.bool)
mask[start_loc+1:end_loc] = True # not including START and END

conv_bpp_mask = torch.zeros(self.Lmax + 2, self.Lmax + 2, dtype=torch.bool)
conv_bpp_mask[start_loc+1:end_loc, start_loc+1:end_loc] = True # not including START and END

forward_mask = torch.zeros(self.Lmax + 2, dtype=torch.bool) # START, seq, END
forward_mask[start_loc:end_loc+1] = True # including START and END



X = {'seq_int': seq_int,
'mask': mask,
"is_good":1,
"forward_mask": forward_mask,
'conv_bpp_mask': conv_bpp_mask}

sid = self.sid[idx]

if self._use_bppm:
adj = _load_bppm(self.sid[idx],
self.Lmax,
self._bppm_path)
adj = np.pad(adj, ((1,1), (1, 1)), constant_values=0)
X['adj'] = torch.from_numpy(adj).float()

return X, {'ids':ids}
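
A hedged end-to-end sketch of how these datasets are presumably wired up for training and inference (the actual entry-point script is not part of this commit; column names and the BPPM file layout follow the code above):

from pathlib import Path

import pandas as pd
from torch.utils.data import DataLoader

from config import ArmNetConfig
from dataset import RNA_Dataset, RNA_Dataset_Test

cfg = ArmNetConfig()
cfg.load("config_example.json")

# Train/eval data: one parquet with sequence, sequence_id, experiment_type,
# SN_filter, signal_to_noise and reactivity_* / reactivity_error_* columns.
df = pd.read_parquet(cfg.train_data_path)

# If use_bppm is set, train_bppm_data_path must hold index.ind (one sequence_id
# per line) and joined.mmap (a C-ordered float32 memmap of shape (N, 206, 206)),
# while test_bppm_data_path holds one <sequence_id>.npy matrix per sequence.
train_bppm = Path(cfg.train_bppm_data_path) if cfg.use_bppm else None

train_ds = RNA_Dataset(df, mode='train', seed=cfg.seed, fold=cfg.fold,
                       nfolds=cfg.num_folds, use_bppm=cfg.use_bppm, bppm_path=train_bppm)
valid_ds = RNA_Dataset(df, mode='eval', seed=cfg.seed, fold=cfg.fold,
                       nfolds=cfg.num_folds, use_bppm=cfg.use_bppm, bppm_path=train_bppm)

train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                      num_workers=cfg.num_workers, drop_last=True)
X, y = next(iter(train_dl))
print(X['seq_int'].shape, y['react'].shape)   # (batch, Lmax+2) and (batch, Lmax+2, 2)

# Test data: a CSV with sequence_id, sequence, id_min and id_max columns.
test_df = pd.read_csv(cfg.test_data_path)
test_ds = RNA_Dataset_Test(test_df, use_bppm=cfg.use_bppm,
                           bppm_path=Path(cfg.test_bppm_data_path) if cfg.use_bppm else None)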