Skip to content

Commit

Permalink
Add two protein datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
limei0307 committed Aug 31, 2022
1 parent f84a23d commit 3667821
Show file tree
Hide file tree
Showing 5 changed files with 513 additions and 2 deletions.
249 changes: 249 additions & 0 deletions dig/threedgraph/dataset/ECdataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
import os.path as osp
import h5py
import numpy as np
import warnings
from tqdm import tqdm

import torch
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset


class ECdataset(InMemoryDataset):
def __init__(self,
root,
transform=None,
pre_transform=None,
pre_filter=None,
split='train'
):

self.split = split
self.root = root

super(ECdataset, self).__init__(
root, transform, pre_transform, pre_filter)

self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
self.data, self.slices = torch.load(self.processed_paths[0])

@property
def processed_dir(self):
name = 'processed'
return osp.join(self.root, name, self.split)

@property
def raw_file_names(self):
name = self.split + '.txt'
return name

@property
def processed_file_names(self):
return 'data.pt'


def _normalize(self,tensor, dim=-1):
'''
Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
'''
return torch.nan_to_num(
torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))

def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos):
# atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1
mask_n = np.char.equal(atom_names, b'N')
mask_ca = np.char.equal(atom_names, b'CA')
mask_c = np.char.equal(atom_names, b'C')
mask_cb = np.char.equal(atom_names, b'CB')
mask_g = np.char.equal(atom_names, b'CG') | np.char.equal(atom_names, b'SG') | np.char.equal(atom_names, b'OG') | np.char.equal(atom_names, b'CG1') | np.char.equal(atom_names, b'OG1')
mask_d = np.char.equal(atom_names, b'CD') | np.char.equal(atom_names, b'SD') | np.char.equal(atom_names, b'CD1') | np.char.equal(atom_names, b'OD1') | np.char.equal(atom_names, b'ND1')
mask_e = np.char.equal(atom_names, b'CE') | np.char.equal(atom_names, b'NE') | np.char.equal(atom_names, b'OE1')
mask_z = np.char.equal(atom_names, b'CZ') | np.char.equal(atom_names, b'NZ')
mask_h = np.char.equal(atom_names, b'NH1')

pos_n = np.full((len(amino_types),3),np.nan)
pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n]
pos_n = torch.FloatTensor(pos_n)

pos_ca = np.full((len(amino_types),3),np.nan)
pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca]
pos_ca = torch.FloatTensor(pos_ca)

pos_c = np.full((len(amino_types),3),np.nan)
pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c]
pos_c = torch.FloatTensor(pos_c)

# if data only contain pos_ca, we set the position of C and N as the position of CA
pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)]
pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)]

pos_cb = np.full((len(amino_types),3),np.nan)
pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb]
pos_cb = torch.FloatTensor(pos_cb)

pos_g = np.full((len(amino_types),3),np.nan)
pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g]
pos_g = torch.FloatTensor(pos_g)

pos_d = np.full((len(amino_types),3),np.nan)
pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d]
pos_d = torch.FloatTensor(pos_d)

pos_e = np.full((len(amino_types),3),np.nan)
pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e]
pos_e = torch.FloatTensor(pos_e)

pos_z = np.full((len(amino_types),3),np.nan)
pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z]
pos_z = torch.FloatTensor(pos_z)

pos_h = np.full((len(amino_types),3),np.nan)
pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h]
pos_h = torch.FloatTensor(pos_h)

return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h


def side_chain_embs(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h):
v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z

# five side chain torsion angles
# We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0.
angle1 = torch.unsqueeze(self.compute_diherals(v1, v2, v3),1)
angle2 = torch.unsqueeze(self.compute_diherals(v2, v3, v4),1)
angle3 = torch.unsqueeze(self.compute_diherals(v3, v4, v5),1)
angle4 = torch.unsqueeze(self.compute_diherals(v4, v5, v6),1)
angle5 = torch.unsqueeze(self.compute_diherals(v5, v6, v7),1)

side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1)
side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1)

return side_chain_embs


def bb_embs(self, X):
# X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue
# N coords: X[:,0,:]
# CA coords: X[:,1,:]
# C coords: X[:,2,:]
# return num_residues x 6
# From https://github.com/jingraham/neurips19-graph-protein-design

X = torch.reshape(X, [3 * X.shape[0], 3])
dX = X[1:] - X[:-1]
U = self._normalize(dX, dim=-1)
u0 = U[:-2]
u1 = U[1:-1]
u2 = U[2:]

angle = self.compute_diherals(u0, u1, u2)

# add phi[0], psi[-1], omega[-1] with value 0
angle = F.pad(angle, [1, 2])
angle = torch.reshape(angle, [-1, 3])
angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1)
return angle_features


def compute_diherals(self, v1, v2, v3):
n1 = torch.cross(v1, v2)
n2 = torch.cross(v2, v3)
a = (n1 * n2).sum(dim=-1)
b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1))
torsion = torch.nan_to_num(torch.atan2(b, a))
return torsion


def protein_to_graph(self, pFilePath):
h5File = h5py.File(pFilePath, "r")
data = Data()

amino_types = h5File['amino_types'][()] # size: (n_amino,)
mask = amino_types == -1
if np.sum(mask) > 0:
amino_types[mask] = 25 # for amino acid types, set the value of -1 to 25
atom_amino_id = h5File['atom_amino_id'][()] # size: (n_atom,)
atom_names = h5File['atom_names'][()] # size: (n_atom,)
atom_pos = h5File['atom_pos'][()][0] #size: (n_atom,3)

# atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1
pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_atom_pos(amino_types, atom_names, atom_amino_id, atom_pos)

# five side chain torsion angles
# We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0.
side_chain_embs = self.side_chain_embs(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h)
side_chain_embs[torch.isnan(side_chain_embs)] = 0
data.side_chain_embs = side_chain_embs

# three backbone torsion angles
bb_embs = self.bb_embs(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1))
bb_embs[torch.isnan(bb_embs)] = 0
data.bb_embs = bb_embs

data.x = torch.unsqueeze(torch.tensor(amino_types),1)
data.coords_ca = pos_ca
data.coords_n = pos_n
data.coords_c = pos_c

assert len(data.x)==len(data.coords_ca)==len(data.coords_n)==len(data.coords_c)==len(data.side_chain_embs)==len(data.bb_embs)

h5File.close()
return data


def process(self):
print('Beginning Processing ...')

# Load the file with the list of functions.
functions_ = []
with open(self.root+"/unique_functions.txt", 'r') as mFile:
for line in mFile:
functions_.append(line.rstrip())

# Get the file list.
if self.split == "Train":
splitFile = "/training.txt"
elif self.split == "Val":
splitFile = "/validation.txt"
elif self.split == "Test":
splitFile = "/testing.txt"

proteinNames_ = []
fileList_ = []
with open(self.root+splitFile, 'r') as mFile:
for line in mFile:
proteinNames_.append(line.rstrip())
fileList_.append(self.root+"/data/"+line.rstrip())

# Load the functions.
print("Reading protein functions")
protFunct_ = {}
with open(self.root+"/chain_functions.txt", 'r') as mFile:
for line in mFile:
splitLine = line.rstrip().split(',')
if splitLine[0] in proteinNames_:
protFunct_[splitLine[0]] = int(splitLine[1])

# Load the dataset
print("Reading the data")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
data_list = []
for fileIter, curFile in tqdm(enumerate(fileList_)):
fileName = curFile.split('/')[-1]
curProtein = self.protein_to_graph(curFile+".hdf5")
curProtein.id = fileName
curProtein.y = torch.tensor(protFunct_[proteinNames_[fileIter]])
if not curProtein.x is None:
data_list.append(curProtein)
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
print('Done!')

if __name__ == "__main__":
for split in ['Train', 'Val', 'Test']:
print('#### Now processing {} data ####'.format(split))
dataset = ECdataset(root='path', split=split)
print(dataset)
Loading

0 comments on commit 3667821

Please sign in to comment.