Skip to content

Commit

Permalink
2023/10/01-18:13:46 (Linux cray unknown)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbenner committed Oct 1, 2023
1 parent 69ea2b0 commit e1d3819
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 44 deletions.
74 changes: 47 additions & 27 deletions coordinationnet/model_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@

import torch

from torch_geometric.data import Data
from torch_geometric.nn import Sequential, GraphConv, HeteroConv, global_mean_pool
from typing import Union

from torch_geometric.data import Data
from torch_geometric.nn import Sequential, GraphConv, CGConv, HeteroConv, global_mean_pool
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor, PairTensor

from .features_coding import NumOxidations, NumGeometries

Expand All @@ -26,6 +30,17 @@

## ----------------------------------------------------------------------------

class IdConv(MessagePassing):
    """Identity "convolution": hands node features back without any message passing.

    The layer has no parameters and ignores the graph structure entirely. It
    exists so that a relation can be listed in a ``HeteroConv`` mapping (and
    thereby keep its node type in the output dictionary) without transforming
    the features of that node type.

    Args:
        **kwargs: Forwarded unchanged to ``MessagePassing.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x: Union[torch.Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor = None) -> torch.Tensor:
        """Return the (destination) node features unchanged.

        Args:
            x: Node features. Either a single tensor, or a
                ``(x_src, x_dst)`` pair for a bipartite relation.
            edge_index: Accepted for interface compatibility; not used.
            edge_attr: Accepted for interface compatibility; not used.

        Returns:
            ``x`` itself when a single tensor is given, otherwise the
            destination features ``x[1]``.
        """
        # A convolution produces one output row per destination node, so for
        # a bipartite (source, destination) pair the identity is the
        # destination tensor. Returning the tuple itself would contradict the
        # declared return type and break downstream aggregation.
        if isinstance(x, tuple):
            return x[1]
        return x

## ----------------------------------------------------------------------------

class ModelGraphCoordinationNet(torch.nn.Module):
def __init__(self,
# Specify model components
Expand All @@ -48,7 +63,7 @@ def __init__(self,
dim_ligand = dim_element + dim_oxidation

if model_config['distances']:
dim_ce += dim_distance
dim_ligand += dim_distance

if model_config['angles']:
dim_ligand += dim_angle
Expand All @@ -60,38 +75,39 @@ def __init__(self,
self.scaler_outputs = TorchStandardScaler(layers[-1])

# RBF encoder
self.rbf_csm = RBFEmbedding(0.0, 1.0, bins=model_config['bins_csm'], edim=dim_csm)
self.rbf_distances = RBFEmbedding(0.0, 1.0, bins=model_config['bins_distance'], edim=dim_distance)
self.rbf_angles = RBFEmbedding(0.0, 1.0, bins=model_config['bins_angle'], edim=dim_angle)
self.rbf_csm = RBFEmbedding(0.0, 1.0, bins=model_config['bins_csm'] , edim=dim_csm)
self.rbf_distances_1 = RBFEmbedding(0.0, 1.0, bins=model_config['bins_distance'], edim=dim_distance)
self.rbf_distances_2 = RBFEmbedding(0.0, 1.0, bins=model_config['bins_distance'], edim=dim_distance)
self.rbf_angles = RBFEmbedding(0.0, 1.0, bins=model_config['bins_angle'] , edim=dim_angle)

# Embeddings
self.embedding_element = ElementEmbedder(edim, from_pretrained=True, freeze=True)
self.embedding_oxidation = torch.nn.Embedding(NumOxidations, 10)
self.embedding_geometry = PaddedEmbedder(NumGeometries, 10)
self.embedding_element = ElementEmbedder(dim_element, from_pretrained=True, freeze=True)
self.embedding_oxidation = torch.nn.Embedding(NumOxidations, dim_oxidation)
self.embedding_geometry = PaddedEmbedder(NumGeometries, dim_geometry)

self.activation = torch.nn.ELU(inplace=True)

# Core graph network
self.layers = Sequential('x, edge_index, batch', [
self.layers = Sequential('x, edge_index, edge_attr, batch', [
# Layer 1 -----------------------------------------------------------------------------------
(HeteroConv({
('site' , '*', 'site' ): GraphConv((dim_site, dim_site), dim_site , add_self_loops=False),
('ligand', '*', 'ce' ): GraphConv((dim_ligand, dim_ce), dim_ce , add_self_loops=True ),
('ce' , '*', 'ligand'): GraphConv((dim_ce, dim_ligand), dim_ligand, add_self_loops=True ),
}), 'x, edge_index -> x'),
# Apply activation
(lambda x: { k : self.activation(v) for k, v in x.items()}, 'x -> x'),
('site' , '*', 'site' ): IdConv(),
('ligand', '*', 'ce' ): CGConv((dim_ligand, dim_ce), dim_distance, add_self_loops=True),
('ce' , '*', 'ligand'): CGConv((dim_ce, dim_ligand), dim_distance, add_self_loops=True),
}), 'x, edge_index, edge_attr -> x'),
# Apply activation, except for site nodes
(lambda x: { k : v if k == 'site' else self.activation(v) for k, v in x.items()}, 'x -> x'),
# Layer 2 -----------------------------------------------------------------------------------
(HeteroConv({
('site' , '*', 'site' ): GraphConv((dim_site, dim_site), dim_site , add_self_loops=False),
('ligand', '*', 'ce' ): GraphConv((dim_ligand, dim_ce), dim_ce , add_self_loops=True ),
('ce' , '*', 'ligand'): GraphConv((dim_ce, dim_ligand), dim_ligand, add_self_loops=True ),
}), 'x, edge_index -> x'),
# Apply activation
(lambda x: { k : self.activation(v) for k, v in x.items()}, 'x -> x'),
('site' , '*', 'site' ): IdConv(),
('ligand', '*', 'ce' ): CGConv((dim_ligand, dim_ce), dim_distance, add_self_loops=True),
('ce' , '*', 'ligand'): CGConv((dim_ce, dim_ligand), dim_distance, add_self_loops=True),
}), 'x, edge_index, edge_attr -> x'),
# Apply activation, except for site nodes
(lambda x: { k : v if k == 'site' else self.activation(v) for k, v in x.items()}, 'x -> x'),
# Layer 4 -----------------------------------------------------------------------------------
(HeteroConv({
('ce', '*', 'site' ): GraphConv((dim_ce, dim_site ), dim_site , add_self_loops=True, bias=False),
('ce', '*', 'site'): GraphConv((dim_ce, dim_site), dim_site, add_self_loops=True, bias=False),
}, aggr='mean'), 'x, edge_index -> x'),
# Apply activation
(lambda x: { k : self.activation(v) for k, v in x.items()}, 'x -> x'),
Expand Down Expand Up @@ -126,9 +142,9 @@ def forward(self, x_input):

# Add optional features
if self.model_config['distances']:
x_ce = torch.cat((
x_ce,
self.rbf_distances(x_input['ce'].x['distances']),
x_ligand = torch.cat((
x_ligand,
self.rbf_distances_1(x_input['ligand'].x['distances']),
), dim=1)

if self.model_config['angles']:
Expand All @@ -141,8 +157,12 @@ def forward(self, x_input):
x = {
'site': x_site, 'ce': x_ce, 'ligand': x_ligand,
}
edge_attr_dict = {
('ligand', '*', 'ce'): self.rbf_distances_2(x_input['ligand', '*', 'ce'].edge_attr),
('ce', '*', 'ligand'): self.rbf_distances_2(x_input['ce', '*', 'ligand'].edge_attr),
}
# Propagate features through graph network
x = self.layers(x, x_input.edge_index_dict, x_input.batch_dict)
x = self.layers(x, x_input.edge_index_dict, edge_attr_dict, x_input.batch_dict)
# Apply final dense layer
x = self.dense(x)
# Apply inverse transformation
Expand Down
50 changes: 33 additions & 17 deletions coordinationnet/model_gnn_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def code_csms(csms) -> list[float]:
def code_distance(distance : float, l : int) -> torch.Tensor:
    """Encode one distance as a 1-D float tensor with `l` identical entries.

    Args:
        distance: The distance value to encode.
        l: Number of entries in the returned tensor, i.e. how many nodes
            receive this distance.

    Returns:
        A ``torch.float`` tensor of length `l` in which every entry equals
        ``distance / 8.0``.
    """
    # NOTE(review): the divisor 8.0 presumably rescales distances towards the
    # [0, 1] range expected by the RBF embeddings -- confirm against the model.
    x = torch.tensor(l*[distance], dtype=torch.float) / 8.0
    return x

def code_angles(angles : list[float]) -> torch.Tensor:
Expand All @@ -48,6 +48,14 @@ def code_angles(angles : list[float]) -> torch.Tensor:
x = torch.tensor(angles, dtype=torch.float) / 180
return x

def get_distance(features, site, site_to):
    """Look up the distance recorded between two sites, in either direction.

    Scans ``features.distances`` in order and returns the ``'distance'`` value
    of the first entry whose ``'site'`` / ``'site_to'`` pair matches
    ``(site, site_to)`` or the reversed pair ``(site_to, site)``.

    Args:
        features: Object exposing an iterable ``distances`` of mappings with
            the keys ``'site'``, ``'site_to'`` and ``'distance'``.
        site: Index of the first site.
        site_to: Index of the second site.

    Returns:
        The ``'distance'`` value of the first matching entry.

    Raises:
        RuntimeError: If no entry connects the two sites.
    """
    for entry in features.distances:
        origin, target = entry['site'], entry['site_to']
        # The pair is treated as undirected: accept either orientation.
        same_direction = origin == site and target == site_to
        reversed_direction = target == site and origin == site_to
        if same_direction or reversed_direction:
            return entry['distance']
    raise RuntimeError('Distance not available')

## ----------------------------------------------------------------------------

class GraphCoordinationData(GenericDataset):
Expand Down Expand Up @@ -101,17 +109,19 @@ def __compute_graph_ce_pairs__(cls, features : CoordinationFeatures, data : Hete
'oxidations': torch.tensor([], dtype=torch.long),
'geometries': torch.tensor([], dtype=torch.long),
'csms' : torch.tensor([], dtype=torch.float),
'distances' : torch.tensor([], dtype=torch.float),
}
x_ligand = {
'elements' : torch.tensor([], dtype=torch.long),
'oxidations': torch.tensor([], dtype=torch.long),
'distances' : torch.tensor([], dtype=torch.float),
'angles' : torch.tensor([], dtype=torch.float),
}
# Edges
e1 = [[], []]
e2 = [[], []]
e3 = [[], []]
edge_index_1 = [[], []]
edge_index_2 = [[], []]
edge_index_3 = [[], []]
# Edge features
edge_attr = []
# Global node index i
i1 = 0
i2 = 0
Expand All @@ -126,24 +136,27 @@ def __compute_graph_ce_pairs__(cls, features : CoordinationFeatures, data : Hete
x_ce['elements' ] = torch.cat((x_ce['elements' ], torch.tensor([ features.sites.elements [site] for site in idx_ce ], dtype=torch.long)))
x_ce['oxidations'] = torch.cat((x_ce['oxidations'], torch.tensor([ features.sites.oxidations[site] for site in idx_ce ], dtype=torch.long)))
x_ce['geometries'] = torch.cat((x_ce['geometries'], torch.tensor([ site_ces[site] for site in idx_ce ], dtype=torch.long )))
x_ce['distances' ] = torch.cat((x_ce['distances' ], code_distance(nb['distance'], l)))
x_ce['csms' ] = torch.cat((x_ce['csms' ], code_csms([ site_csm[site] for site in idx_ce ])))
# Construct ligand features
x_ligand['elements' ] = torch.cat((x_ligand['elements' ], torch.tensor([ features.sites.elements [site] for site in idx_ligand ], dtype=torch.long)))
x_ligand['oxidations'] = torch.cat((x_ligand['oxidations'], torch.tensor([ features.sites.oxidations[site] for site in idx_ligand ], dtype=torch.long)))
x_ligand['angles' ] = torch.cat((x_ligand['angles' ], code_angles(nb['angles'])))
x_ligand['distances' ] = torch.cat((x_ligand['distances' ], code_distance(nb['distance'], l)))
x_ligand['angles' ] = torch.cat((x_ligand['angles' ], code_angles (nb['angles'])))

for j, _ in enumerate(nb['ligand_indices']):
for j, k in enumerate(nb['ligand_indices']):
# From ligand ; To CE
e1[0].append(i2+j); e1[1].append(i1+0)
e1[0].append(i2+j); e1[1].append(i1+1)
edge_index_1[0].append(i2+j); edge_index_1[1].append(i1+0)
edge_index_1[0].append(i2+j); edge_index_1[1].append(i1+1)
# From CE ; To ligand
e2[0].append(i1+0); e2[1].append(i2+j)
e2[0].append(i1+1); e2[1].append(i2+j)
edge_index_2[0].append(i1+0); edge_index_2[1].append(i2+j)
edge_index_2[0].append(i1+1); edge_index_2[1].append(i2+j)
# ligand-CE features
edge_attr.append(get_distance(features, idx_ce[0], k))
edge_attr.append(get_distance(features, idx_ce[1], k))

# Connect CE nodes to site nodes
e3[0].append(i1+0); e3[1].append(nb['site'])
e3[0].append(i1+1); e3[1].append(nb['site_to'])
edge_index_3[0].append(i1+0); edge_index_3[1].append(nb['site'])
edge_index_3[0].append(i1+1); edge_index_3[1].append(nb['site_to'])

i1 += 2
i2 += len(nb['ligand_indices'])
Expand All @@ -155,10 +168,13 @@ def __compute_graph_ce_pairs__(cls, features : CoordinationFeatures, data : Hete
data['ce' ].x = x_ce
data['ligand'].x = x_ligand
# Assign edges
data['ligand', '*', 'ce'].edge_index = torch.tensor(e1, dtype=torch.long)
data['ce', '*', 'ligand'].edge_index = torch.tensor(e2, dtype=torch.long)
data['ligand', '*', 'ce'].edge_index = torch.tensor(edge_index_1, dtype=torch.long)
data['ce', '*', 'ligand'].edge_index = torch.tensor(edge_index_2, dtype=torch.long)
# Assign edge features
data['ligand', '*', 'ce'].edge_attr = torch.tensor(edge_attr, dtype=torch.long)
data['ce', '*', 'ligand'].edge_attr = torch.tensor(edge_attr, dtype=torch.long)
# Connect CE nodes to site nodes
data['ce', '*', 'site'].edge_index = torch.tensor(e3, dtype=torch.long)
data['ce', '*', 'site'].edge_index = torch.tensor(edge_index_3, dtype=torch.long)

@classmethod
def __compute_graph__(cls, features : CoordinationFeatures) -> GraphData:
Expand Down

0 comments on commit e1d3819

Please sign in to comment.