From 3667821bd76a1b211b4583f83653ecfb1e40af7f Mon Sep 17 00:00:00 2001
From: limei0307
Date: Wed, 31 Aug 2022 13:44:58 -0500
Subject: [PATCH] Add two protein datasets

---
 dig/threedgraph/dataset/ECdataset.py   | 222 +++++++++++++++++++++++++
 dig/threedgraph/dataset/FOLDdataset.py | 213 +++++++++++++++++++++++
 dig/threedgraph/dataset/README.md      |  20 ++
 dig/threedgraph/dataset/__init__.py    |   6 +-
 dig/threedgraph/method/__init__.py     |   4 +-
 5 files changed, 463 insertions(+), 2 deletions(-)
 create mode 100644 dig/threedgraph/dataset/ECdataset.py
 create mode 100644 dig/threedgraph/dataset/FOLDdataset.py
 create mode 100644 dig/threedgraph/dataset/README.md

diff --git a/dig/threedgraph/dataset/ECdataset.py b/dig/threedgraph/dataset/ECdataset.py
new file mode 100644
index 00000000..071445fa
--- /dev/null
+++ b/dig/threedgraph/dataset/ECdataset.py
@@ -0,0 +1,222 @@
+import os.path as osp
+import h5py
+import numpy as np
+import warnings
+from tqdm import tqdm
+
+import torch
+import torch.nn.functional as F
+
+from torch_geometric.data import Data
+from torch_geometric.data import InMemoryDataset
+
+
+class ECdataset(InMemoryDataset):
+    """Protein dataset built from per-chain HDF5 files and function labels.
+
+    ``split`` picks the chain list under ``root``: 'Train', 'Val' or 'Test'
+    (case-insensitive). The processed data is cached under
+    ``root/processed/<split>/data.pt``.
+    """
+
+    def __init__(self,
+                 root,
+                 transform=None,
+                 pre_transform=None,
+                 pre_filter=None,
+                 split='train'
+                 ):
+
+        # Set before the base constructor runs: processed_dir needs both.
+        self.split = split
+        self.root = root
+
+        super(ECdataset, self).__init__(
+            root, transform, pre_transform, pre_filter)
+
+        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
+        self.data, self.slices = torch.load(self.processed_paths[0])
+
+    @property
+    def processed_dir(self):
+        """Directory holding the processed file of this split."""
+        return osp.join(self.root, 'processed', self.split)
+
+    @property
+    def raw_file_names(self):
+        return self.split + '.txt'
+
+    @property
+    def processed_file_names(self):
+        return 'data.pt'
+
+    def _normalize(self, tensor, dim=-1):
+        """Normalize `tensor` along `dim`, mapping `nan` results to 0."""
+        return torch.nan_to_num(
+            torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
+
+    def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos):
+        """Return per-residue positions of the atoms used for torsion angles.
+
+        Each returned tensor has one row per residue; rows of residues that
+        lack the atom are `nan`, except that missing N and C fall back to CA.
+        """
+        # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1
+        mask_n = np.char.equal(atom_names, b'N')
+        mask_ca = np.char.equal(atom_names, b'CA')
+        mask_c = np.char.equal(atom_names, b'C')
+        mask_cb = np.char.equal(atom_names, b'CB')
+        mask_g = np.char.equal(atom_names, b'CG') | np.char.equal(atom_names, b'SG') | np.char.equal(atom_names, b'OG') | np.char.equal(atom_names, b'CG1') | np.char.equal(atom_names, b'OG1')
+        mask_d = np.char.equal(atom_names, b'CD') | np.char.equal(atom_names, b'SD') | np.char.equal(atom_names, b'CD1') | np.char.equal(atom_names, b'OD1') | np.char.equal(atom_names, b'ND1')
+        mask_e = np.char.equal(atom_names, b'CE') | np.char.equal(atom_names, b'NE') | np.char.equal(atom_names, b'OE1')
+        mask_z = np.char.equal(atom_names, b'CZ') | np.char.equal(atom_names, b'NZ')
+        mask_h = np.char.equal(atom_names, b'NH1')
+
+        def gather(mask):
+            # Scatter the selected atom coordinates into a per-residue array.
+            pos = np.full((len(amino_types), 3), np.nan)
+            pos[atom_amino_id[mask]] = atom_pos[mask]
+            return torch.FloatTensor(pos)
+
+        pos_n, pos_ca, pos_c = gather(mask_n), gather(mask_ca), gather(mask_c)
+
+        # if data only contain pos_ca, we set the position of C and N as the position of CA
+        pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)]
+        pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)]
+
+        pos_cb, pos_g, pos_d = gather(mask_cb), gather(mask_g), gather(mask_d)
+        pos_e, pos_z, pos_h = gather(mask_e), gather(mask_z), gather(mask_h)
+
+        return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h
+
+    def side_chain_embs(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h):
+        """Return sin and cos of the first four side-chain torsion angles."""
+        v1, v2, v3, v4, v5, v6 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e
+
+        # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0.
+        angles = [self.compute_diherals(a, b, c)
+                  for a, b, c in ((v1, v2, v3), (v2, v3, v4), (v3, v4, v5), (v4, v5, v6))]
+        side_chain_angles = torch.stack(angles, 1)
+        side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1)
+
+        return side_chain_embs
+
+    def bb_embs(self, X):
+        """Return cos and sin of the backbone torsion angles of each residue."""
+        # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue
+        # N coords: X[:,0,:]
+        # CA coords: X[:,1,:]
+        # C coords: X[:,2,:]
+        # return num_residues x 6
+        # From https://github.com/jingraham/neurips19-graph-protein-design
+
+        X = torch.reshape(X, [3 * X.shape[0], 3])
+        dX = X[1:] - X[:-1]
+        U = self._normalize(dX, dim=-1)
+        u0 = U[:-2]
+        u1 = U[1:-1]
+        u2 = U[2:]
+
+        angle = self.compute_diherals(u0, u1, u2)
+
+        # add phi[0], psi[-1], omega[-1] with value 0
+        angle = F.pad(angle, [1, 2])
+        angle = torch.reshape(angle, [-1, 3])
+        angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1)
+        return angle_features
+
+    def compute_diherals(self, v1, v2, v3):
+        """Return the torsion angle defined by consecutive vectors v1, v2, v3."""
+        # dim=-1 is explicit: without it torch.cross uses the first size-3
+        # dimension, which is the row axis when there are exactly three rows.
+        n1 = torch.cross(v1, v2, dim=-1)
+        n2 = torch.cross(v2, v3, dim=-1)
+        a = (n1 * n2).sum(dim=-1)
+        b = torch.nan_to_num((torch.cross(n1, n2, dim=-1) * v2).sum(dim=-1) / v2.norm(dim=-1))
+        torsion = torch.nan_to_num(torch.atan2(b, a))
+        return torsion
+
+    def protein_to_graph(self, pFilePath):
+        """Read one protein HDF5 file and return it as a `Data` object."""
+        # The context manager closes the file even if a dataset is missing.
+        with h5py.File(pFilePath, "r") as h5File:
+            amino_types = h5File['amino_types'][()] # size: (n_amino,)
+            atom_amino_id = h5File['atom_amino_id'][()] # size: (n_atom,)
+            atom_names = h5File['atom_names'][()] # size: (n_atom,)
+            atom_pos = h5File['atom_pos'][()][0] #size: (n_atom,3)
+
+        amino_types[amino_types == -1] = 25 # for amino acid types, set the value of -1 to 25
+
+        # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1
+        pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_atom_pos(amino_types, atom_names, atom_amino_id, atom_pos)
+
+        data = Data()
+
+        # sin/cos of the first four side chain torsion angles
+        side_chain_embs = self.side_chain_embs(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h)
+        side_chain_embs[torch.isnan(side_chain_embs)] = 0
+        data.side_chain_embs = side_chain_embs
+
+        # three backbone torsion angles
+        bb_embs = self.bb_embs(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1))
+        bb_embs[torch.isnan(bb_embs)] = 0
+        data.bb_embs = bb_embs
+
+        data.x = torch.unsqueeze(torch.tensor(amino_types),1)
+        data.coords_ca = pos_ca
+        data.coords_n = pos_n
+        data.coords_c = pos_c
+
+        assert len(data.x)==len(data.coords_ca)==len(data.coords_n)==len(data.coords_c)==len(data.side_chain_embs)==len(data.bb_embs)
+
+        return data
+
+    def process(self):
+        """Build the graphs of the selected split and save them to disk.
+
+        Raises:
+            ValueError: if `split` is not 'Train', 'Val' or 'Test' (any case).
+        """
+        print('Beginning Processing ...')
+
+        # Get the file list. The lookup is case-insensitive so that the
+        # constructor default 'train' resolves like the documented 'Train'.
+        split_files = {'train': '/training.txt', 'val': '/validation.txt', 'test': '/testing.txt'}
+        if self.split.lower() not in split_files:
+            raise ValueError("Unknown split '{}'; expected 'Train', 'Val' or 'Test'".format(self.split))
+        splitFile = split_files[self.split.lower()]
+
+        with open(self.root+splitFile, 'r') as mFile:
+            proteinNames_ = [line.rstrip() for line in mFile]
+        fileList_ = [self.root+"/data/"+name for name in proteinNames_]
+
+        # Load the functions.
+        print("Reading protein functions")
+        wanted = set(proteinNames_)  # constant-time membership tests
+        protFunct_ = {}
+        with open(self.root+"/chain_functions.txt", 'r') as mFile:
+            for line in mFile:
+                splitLine = line.rstrip().split(',')
+                if splitLine[0] in wanted:
+                    protFunct_[splitLine[0]] = int(splitLine[1])
+
+        # Load the dataset
+        print("Reading the data")
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            data_list = []
+            for name, curFile in tqdm(zip(proteinNames_, fileList_), total=len(fileList_)):
+                curProtein = self.protein_to_graph(curFile+".hdf5")
+                curProtein.id = curFile.split('/')[-1]
+                curProtein.y = torch.tensor(protFunct_[name])
+                if curProtein.x is not None:
+                    data_list.append(curProtein)
+        data, slices = self.collate(data_list)
+        torch.save((data, slices), self.processed_paths[0])
+        print('Done!')
+
+if __name__ == "__main__":
+    for split in ['Train', 'Val', 'Test']:
+        print('#### Now processing {} data ####'.format(split))
+        dataset = ECdataset(root='path', split=split)
+        print(dataset)
\ No newline at end of file
diff --git a/dig/threedgraph/dataset/FOLDdataset.py b/dig/threedgraph/dataset/FOLDdataset.py
new file mode 100644
index 00000000..aca184c5
--- /dev/null
+++ b/dig/threedgraph/dataset/FOLDdataset.py
@@ -0,0 +1,213 @@
+import os.path as osp
+import h5py
+import numpy as np
+import warnings
+from tqdm import tqdm
+
+import torch
+import torch.nn.functional as F
+
+from torch_geometric.data import Data
+from torch_geometric.data import InMemoryDataset
+
+
+class FOLDdataset(InMemoryDataset):
+    """Protein dataset built from per-chain HDF5 files and class labels.
+
+    ``split`` names both the list file ``root/<split>.txt`` and the folder
+    ``root/<split>/`` with the HDF5 files, e.g. 'training'. The processed
+    data is cached under ``root/processed/<split>/data.pt``.
+    """
+
+    def __init__(self,
+                 root,
+                 transform=None,
+                 pre_transform=None,
+                 pre_filter=None,
+                 split='train'
+                 ):
+
+        # Set before the base constructor runs: processed_dir needs both.
+        self.split = split
+        self.root = root
+
+        super(FOLDdataset, self).__init__(
+            root, transform, pre_transform, pre_filter)
+
+        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
+        self.data, self.slices = torch.load(self.processed_paths[0])
+
+    @property
+    def processed_dir(self):
+        """Directory holding the processed file of this split."""
+        return osp.join(self.root, 'processed', self.split)
+
+    @property
+    def raw_file_names(self):
+        return self.split + '.txt'
+
+    @property
+    def processed_file_names(self):
+        return 'data.pt'
+
+    def _normalize(self, tensor, dim=-1):
+        """Normalize `tensor` along `dim`, mapping `nan` results to 0."""
+        return torch.nan_to_num(
+            torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
+
+    def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos):
+        """Return per-residue positions of the atoms used for torsion angles.
+
+        Each returned tensor has one row per residue; rows of residues that
+        lack the atom are `nan`, except that missing N and C fall back to CA.
+        """
+        # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1
+        mask_n = np.char.equal(atom_names, b'N')
+        mask_ca = np.char.equal(atom_names, b'CA')
+        mask_c = np.char.equal(atom_names, b'C')
+        mask_cb = np.char.equal(atom_names, b'CB')
+        mask_g = np.char.equal(atom_names, b'CG') | np.char.equal(atom_names, b'SG') | np.char.equal(atom_names, b'OG') | np.char.equal(atom_names, b'CG1') | np.char.equal(atom_names, b'OG1')
+        mask_d = np.char.equal(atom_names, b'CD') | np.char.equal(atom_names, b'SD') | np.char.equal(atom_names, b'CD1') | np.char.equal(atom_names, b'OD1') | np.char.equal(atom_names, b'ND1')
+        mask_e = np.char.equal(atom_names, b'CE') | np.char.equal(atom_names, b'NE') | np.char.equal(atom_names, b'OE1')
+        mask_z = np.char.equal(atom_names, b'CZ') | np.char.equal(atom_names, b'NZ')
+        mask_h = np.char.equal(atom_names, b'NH1')
+
+        def gather(mask):
+            # Scatter the selected atom coordinates into a per-residue array.
+            pos = np.full((len(amino_types), 3), np.nan)
+            pos[atom_amino_id[mask]] = atom_pos[mask]
+            return torch.FloatTensor(pos)
+
+        pos_n, pos_ca, pos_c = gather(mask_n), gather(mask_ca), gather(mask_c)
+
+        # if data only contain pos_ca, we set the position of C and N as the position of CA
+        pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)]
+        pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)]
+
+        pos_cb, pos_g, pos_d = gather(mask_cb), gather(mask_g), gather(mask_d)
+        pos_e, pos_z, pos_h = gather(mask_e), gather(mask_z), gather(mask_h)
+
+        return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h
+
+    def side_chain_embs(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h):
+        """Return sin and cos of the first four side-chain torsion angles."""
+        v1, v2, v3, v4, v5, v6 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e
+
+        # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0.
+        angles = [self.compute_diherals(a, b, c)
+                  for a, b, c in ((v1, v2, v3), (v2, v3, v4), (v3, v4, v5), (v4, v5, v6))]
+        side_chain_angles = torch.stack(angles, 1)
+        side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1)
+
+        return side_chain_embs
+
+    def bb_embs(self, X):
+        """Return cos and sin of the backbone torsion angles of each residue."""
+        # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue
+        # N coords: X[:,0,:]
+        # CA coords: X[:,1,:]
+        # C coords: X[:,2,:]
+        # return num_residues x 6
+        # From https://github.com/jingraham/neurips19-graph-protein-design
+
+        X = torch.reshape(X, [3 * X.shape[0], 3])
+        dX = X[1:] - X[:-1]
+        U = self._normalize(dX, dim=-1)
+        u0 = U[:-2]
+        u1 = U[1:-1]
+        u2 = U[2:]
+
+        angle = self.compute_diherals(u0, u1, u2)
+
+        # add phi[0], psi[-1], omega[-1] with value 0
+        angle = F.pad(angle, [1, 2])
+        angle = torch.reshape(angle, [-1, 3])
+        angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1)
+        return angle_features
+
+    def compute_diherals(self, v1, v2, v3):
+        """Return the torsion angle defined by consecutive vectors v1, v2, v3."""
+        # dim=-1 is explicit: without it torch.cross uses the first size-3
+        # dimension, which is the row axis when there are exactly three rows.
+        n1 = torch.cross(v1, v2, dim=-1)
+        n2 = torch.cross(v2, v3, dim=-1)
+        a = (n1 * n2).sum(dim=-1)
+        b = torch.nan_to_num((torch.cross(n1, n2, dim=-1) * v2).sum(dim=-1) / v2.norm(dim=-1))
+        torsion = torch.nan_to_num(torch.atan2(b, a))
+        return torsion
+
+    def protein_to_graph(self, pFilePath):
+        """Read one protein HDF5 file and return it as a `Data` object."""
+        # The context manager closes the file even if a dataset is missing.
+        with h5py.File(pFilePath, "r") as h5File:
+            amino_types = h5File['amino_types'][()] # size: (n_amino,)
+            atom_amino_id = h5File['atom_amino_id'][()] # size: (n_atom,)
+            atom_names = h5File['atom_names'][()] # size: (n_atom,)
+            atom_pos = h5File['atom_pos'][()][0] #size: (n_atom,3)
+
+        amino_types[amino_types == -1] = 25 # for amino acid types, set the value of -1 to 25
+
+        # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1
+        pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_atom_pos(amino_types, atom_names, atom_amino_id, atom_pos)
+
+        data = Data()
+
+        # sin/cos of the first four side chain torsion angles
+        side_chain_embs = self.side_chain_embs(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h)
+        side_chain_embs[torch.isnan(side_chain_embs)] = 0
+        data.side_chain_embs = side_chain_embs
+
+        # three backbone torsion angles
+        bb_embs = self.bb_embs(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1))
+        bb_embs[torch.isnan(bb_embs)] = 0
+        data.bb_embs = bb_embs
+
+        data.x = torch.unsqueeze(torch.tensor(amino_types),1)
+        data.coords_ca = pos_ca
+        data.coords_n = pos_n
+        data.coords_c = pos_c
+
+        assert len(data.x)==len(data.coords_ca)==len(data.coords_n)==len(data.coords_c)==len(data.side_chain_embs)==len(data.bb_embs)
+
+        return data
+
+    def process(self):
+        """Build the graphs of the selected split and save them to disk."""
+        print('Beginning Processing ...')
+
+        # Load the map from class name to integer label.
+        classes_ = {}
+        with open(self.root+"/class_map.txt", 'r') as mFile:
+            for line in mFile:
+                lineList = line.rstrip().split('\t')
+                classes_[lineList[0]] = int(lineList[1])
+
+        # Get the file list.
+        fileList_ = []
+        categories_ = []
+        with open(self.root+"/"+self.split+".txt", 'r') as mFile:
+            for curLine in mFile:
+                splitLine = curLine.rstrip().split('\t')
+                fileList_.append(self.root+"/"+self.split+"/"+splitLine[0])
+                categories_.append(classes_[splitLine[-1]])
+
+        # Load the dataset
+        print("Reading the data")
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            data_list = []
+            for curFile, curClass in tqdm(zip(fileList_, categories_), total=len(fileList_)):
+                curProtein = self.protein_to_graph(curFile+".hdf5")
+                curProtein.id = curFile.split('/')[-1]
+                curProtein.y = torch.tensor(curClass)
+                if curProtein.x is not None:
+                    data_list.append(curProtein)
+        data, slices = self.collate(data_list)
+        torch.save((data, slices), self.processed_paths[0])
+        print('Done!')
+
+if __name__ == "__main__":
+    for split in ['training', 'validation', 'test_fold', 'test_superfamily', 'test_family']:
+        print('#### Now processing {} data ####'.format(split))
+        dataset = FOLDdataset(root='path', split=split)
+        print(dataset)
\ No newline at end of file
diff --git a/dig/threedgraph/dataset/README.md b/dig/threedgraph/dataset/README.md
new file mode 100644
index 00000000..35d45548
--- /dev/null
+++ b/dig/threedgraph/dataset/README.md
@@ -0,0 +1,20 @@
+# Datasets
+
+## ECdataset and FOLDdataset
+
+For ECdataset and FOLDdataset, please download datasets from [here](https://github.com/phermosilla/IEConv_proteins#download-the-preprocessed-datasets) to a path. Then set the parameter `root='path'` to load and process the data.
+
+Usage example:
+```python
+# ECdataset
+for split in ['Train', 'Val', 'Test']:
+    print('#### Now processing {} data ####'.format(split))
+    dataset = ECdataset(root='path', split=split)
+    print(dataset)
+
+# FOLDdataset
+for split in ['training', 'validation', 'test_fold', 'test_superfamily', 'test_family']:
+    print('#### Now processing {} data ####'.format(split))
+    dataset = FOLDdataset(root='path', split=split)
+    print(dataset)
+```
\ No newline at end of file
diff --git a/dig/threedgraph/dataset/__init__.py b/dig/threedgraph/dataset/__init__.py
index 7693502b..31ee2836 100644
--- a/dig/threedgraph/dataset/__init__.py
+++ b/dig/threedgraph/dataset/__init__.py
@@ -1,7 +1,11 @@
 from .PygQM93D import QM93D
 from .PygMD17 import MD17
+from .ECdataset import ECdataset
+from .FOLDdataset import FOLDdataset
 
 __all__ = [
     'QM93D',
-    'MD17'
+    'MD17',
+    'ECdataset',
+    'FOLDdataset'
 ]
\ No newline at end of file
diff --git a/dig/threedgraph/method/__init__.py b/dig/threedgraph/method/__init__.py
index eef22c36..2480d760 100644
--- a/dig/threedgraph/method/__init__.py
+++ b/dig/threedgraph/method/__init__.py
@@ -3,6 +3,7 @@
 from .dimenetpp import DimeNetPP
 from .spherenet import SphereNet
 from .comenet import ComENet
+from .pronet import ProNet
 
 
 __all__ = [
@@ -10,5 +11,6 @@
     'SchNet',
     'DimeNetPP',
     'SphereNet',
-    'ComENet'
+    'ComENet',
+    'ProNet'
 ]
\ No newline at end of file