-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
attentive_fp.py
148 lines (120 loc) · 5.28 KB
/
attentive_fp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os.path as osp
from math import sqrt
import torch
import torch.nn.functional as F
from rdkit import Chem
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import AttentiveFP
class GenFeatures:
    """Build AttentiveFP atom and bond features from a molecule's SMILES.

    Instances are callables intended for use as a dataset ``pre_transform``:
    given an object with a ``smiles`` attribute they set ``x``,
    ``edge_index`` and ``edge_attr`` on it and return it.

    Each atom row has 39 entries: symbol one-hot (16), degree one-hot (6),
    formal charge (1), radical electrons (1), hybridization one-hot (6),
    aromaticity (1), total-hydrogen one-hot (5), chirality flag (1) and
    R/S one-hot (2). Each bond row has 10 entries: single/double/triple/
    aromatic flags (4), conjugation (1), ring membership (1) and stereo
    one-hot (4).
    """

    def __init__(self):
        # The trailing 'other' entry is the catch-all bucket for elements
        # that are not listed explicitly.
        self.symbols = [
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',
            'Te', 'I', 'At', 'other'
        ]
        # Same convention: the last slot collects unlisted hybridizations.
        self.hybridizations = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            'other',
        ]
        # No catch-all here: a stereo value outside this list raises
        # ValueError in `_bond_features`.
        self.stereos = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
        ]

    @staticmethod
    def _one_hot(choices, value):
        """One-hot encode ``value``; unlisted values fall into the last slot."""
        try:
            position = choices.index(value)
        except ValueError:
            # Previously an unlisted value crashed here even though the
            # vocabulary reserves its final entry ('other') for this case.
            position = len(choices) - 1
        encoding = [0.] * len(choices)
        encoding[position] = 1.
        return encoding

    def _atom_features(self, atom):
        """Return the flat feature list for a single RDKit atom."""
        symbol = self._one_hot(self.symbols, atom.GetSymbol())

        # Degrees 0-5 are representable; a larger degree raises IndexError.
        degree = [0.] * 6
        degree[atom.GetDegree()] = 1.

        hybridization = self._one_hot(self.hybridizations,
                                      atom.GetHybridization())
        aromaticity = 1. if atom.GetIsAromatic() else 0.

        # Hydrogen counts 0-4 are representable; more raises IndexError.
        hydrogens = [0.] * 5
        hydrogens[atom.GetTotalNumHs()] = 1.

        chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.
        chirality_type = [0.] * 2
        if atom.HasProp('_CIPCode'):
            chirality_type[['R', 'S'].index(atom.GetProp('_CIPCode'))] = 1.

        return (symbol + degree + [atom.GetFormalCharge()] +
                [atom.GetNumRadicalElectrons()] + hybridization +
                [aromaticity] + hydrogens + [chirality] + chirality_type)

    def _bond_features(self, bond):
        """Return the flat feature list for a single RDKit bond."""
        bond_type = bond.GetBondType()
        single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
        double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
        triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
        aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
        conjugation = 1. if bond.GetIsConjugated() else 0.
        ring = 1. if bond.IsInRing() else 0.

        stereo = [0.] * 4
        stereo[self.stereos.index(bond.GetStereo())] = 1.

        return [single, double, triple, aromatic, conjugation, ring] + stereo

    def __call__(self, data):
        """Attach node and edge features to ``data`` and return it.

        Args:
            data: Object exposing a ``smiles`` string. It is modified in
                place.

        Returns:
            The same ``data`` object with ``x``, ``edge_index`` and
            ``edge_attr`` set.

        Raises:
            ValueError: If ``data.smiles`` cannot be parsed into a molecule,
                or a bond has a stereo value not listed in ``self.stereos``.
        """
        # Generate AttentiveFP features according to Table 1.
        mol = Chem.MolFromSmiles(data.smiles)
        if mol is None:
            # Fail with the offending input instead of an AttributeError on
            # None further down.
            raise ValueError(f'Could not parse SMILES: {data.smiles!r}')

        xs = [torch.tensor(self._atom_features(atom))
              for atom in mol.GetAtoms()]
        data.x = torch.stack(xs, dim=0)

        edge_indices = []
        edge_attrs = []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Every bond is stored as two directed edges that share one
            # attribute vector.
            edge_indices += [[begin, end], [end, begin]]
            edge_attr = torch.tensor(self._bond_features(bond))
            edge_attrs += [edge_attr, edge_attr]

        if not edge_attrs:
            # Molecule without bonds: keep well-formed, empty edge tensors.
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            data.edge_attr = torch.zeros((0, 10), dtype=torch.float)
        else:
            data.edge_index = torch.tensor(edge_indices).t().contiguous()
            data.edge_attr = torch.stack(edge_attrs, dim=0)

        return data
# Dataset root: ../data/AFP_Mol relative to the directory of this file.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'AFP_Mol')

# ESOL with AttentiveFP features generated once as a pre-transform, then
# shuffled before splitting.
dataset = MoleculeNet(path, name='ESOL', pre_transform=GenFeatures())
dataset = dataset.shuffle()

# 10% validation, 10% test, remaining ~80% training.
N = len(dataset) // 10
val_dataset, test_dataset, train_dataset = (
    dataset[:N],
    dataset[N:2 * N],
    dataset[2 * N:],
)

# Only the training loader reshuffles its samples.
train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=200)
test_loader = DataLoader(test_dataset, batch_size=200)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# in_channels=39 and edge_dim=10 match the widths of the node and edge
# feature vectors produced by GenFeatures.
model = AttentiveFP(
    in_channels=39,
    hidden_channels=200,
    out_channels=1,
    edge_dim=10,
    num_layers=2,
    num_timesteps=2,
    dropout=0.2,
)
model = model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10**-2.5,
    weight_decay=10**-5,
)
def train():
    """Run one optimisation epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``.

    Returns:
        The square root of the graph-count-weighted mean of the per-batch
        MSE losses, i.e. the training RMSE for this epoch.
    """
    # Set training mode explicitly so dropout is active regardless of the
    # mode any earlier evaluation left the model in.
    model.train()

    total_loss = total_examples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()
        # `loss` is a batch mean; weight it by the batch size so the epoch
        # average is taken over graphs rather than over batches.
        total_loss += float(loss) * data.num_graphs
        total_examples += data.num_graphs
    return sqrt(total_loss / total_examples)
@torch.no_grad()
def test(loader):
    """Compute the RMSE of the module-level ``model`` over ``loader``.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch`` and ``y``.

    Returns:
        The root mean squared error over all target entries, as a float.
    """
    # Evaluate in eval mode so dropout does not perturb the predictions,
    # then restore the previous mode so training is unaffected by this call.
    was_training = model.training
    model.eval()
    try:
        squared_errors = []
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            # Keep element-wise errors so the final mean is exact across
            # batches of different sizes.
            squared_errors.append(
                F.mse_loss(out, data.y, reduction='none').cpu())
    finally:
        model.train(was_training)
    return float(torch.cat(squared_errors, dim=0).mean().sqrt())
# Train for 200 epochs and report train/validation/test RMSE after each one.
for epoch in range(1, 200 + 1):
    train_rmse = train()
    # Evaluated left to right: validation first, then test.
    val_rmse, test_rmse = test(val_loader), test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
          f'Test: {test_rmse:.4f}')