-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathdata_process.py
115 lines (99 loc) · 4.42 KB
/
data_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import pickle
import lmdb
import torch
from tqdm.auto import tqdm
import os.path as osp
from utils.chem import read_pkl
from utils.transforms import *
from utils.misc import *
from utils.surface import read_ply
from utils.protein_ligand import parse_sdf_file, parse_rdmol
from utils.data import ProteinLigandData, torchify_dict
from utils.surface import geodesic_matrix, dst2knnedge, parse_face, gds_edge_process
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='./configs/train.yml')
parser.add_argument('--index_path', type=str, default='./data/crossdock_data/index.pkl',
help='a list storing each pair information, including surf_file, lig_file, protein_file')
parser.add_argument('--processed_path', type=str, default='./data/crossdock_data_gds.lmdb',
help='the path to store the processed data')
parser.add_argument('--name2id_path', type=str, default='./data/crossdock_gdsname2id.pt',
help='the path to store the name2id dict, which is used to split the dataset according to the split_name.pt')
parser.add_argument('--surf_path', type=str, default='/home/haotian/Molecule_Generation/SurfGen/data/crossdock2020_surface_8',
help='the path storing surface files')
parser.add_argument('--lig_path', type=str, default='/home/haotian/Molecule_Generation/SurfGen/data/crossdocked_pocket10',
help='the path storing ligand files')
args = parser.parse_args()
config = load_config(args.config)
protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom()
masking = get_mask(config.train.transform.mask)
composer = AtomComposer(protein_featurizer.feature_dim, ligand_featurizer.feature_dim, config.model.encoder.knn)
edge_sampler = EdgeSample(config.train.transform.edgesampler)
cfg_ctr = config.train.transform.contrastive
contrastive_sampler = ContrastiveSample(cfg_ctr.num_real, cfg_ctr.num_fake, cfg_ctr.pos_real_std, cfg_ctr.pos_fake_std, config.model.field.knn)
transform = Compose([
RefineData(),
LigandCountNeighbors(),
protein_featurizer,
ligand_featurizer,
masking,
composer,
FocalBuilder(),
edge_sampler,
contrastive_sampler,
])
index = read_pkl(parser.index_path)
db = lmdb.open(
args.processed_path,
map_size=50*(1024*1024*1024), # 10GB
create=True,
subdir=False,
readonly=False, # Writable
)
num_skipped = 0
import time
start=time.time()
with db.begin(write=True, buffers=True) as txn:
for i, (pocket_nm, ligand_nm, protein_nm,_) in enumerate(tqdm(index)):
if pocket_nm is None:
continue
try:
surf_nm = pocket_nm[:-6]+'_8.0_res_1.5.ply'
sdf_file = osp.join(args.lig_path,ligand_nm)
ply_file = osp.join(args.surf_path,surf_nm)
pocket_dict = read_ply(ply_file)
ligand_dict = parse_sdf_file(sdf_file)
data = ProteinLigandData.from_protein_ligand_dicts(
protein_dict=torchify_dict(pocket_dict),
ligand_dict=torchify_dict(ligand_dict),
)
data.pocket_filename = pocket_nm
data.protein_filename = protein_nm
data.ligand_filename = ligand_nm
data.surface_filename = surf_nm
data.face = parse_face(ply_file) # read_ply_geom(ply_file,read_face=True).face
edge_index = torch.cat([data.face[:2], data.face[1:], data.face[::2]], dim=1)
dlny_edge_index = to_undirected(edge_index, num_nodes=data.protein_pos.shape[0])
gds_mat = geodesic_matrix(data.protein_pos, dlny_edge_index)
gds_knn_edge_index, gds_knn_edge_dist = dst2knnedge(gds_mat, num_knn=16)
data.gds_knn_edge_index = gds_knn_edge_index
data.gds_dist = gds_knn_edge_dist
txn.put(
key = str(i).encode(),
value = pickle.dumps(data)
)
except Exception as e:
print(e)
num_skipped += 1
if num_skipped%100 == 0:
print('Skipping (%d) %s' % (num_skipped, ligand_nm, ))
db.close()
print('finished, {}'.format(time.time()-start))
# create name2id if not exists
SurfLigandPairDataset(index_path=args.index_path,
processed_path=args.processed_path,
name2id=args.name2id_path, transform=transform)