diff --git a/coordinationnet/model_gnn.py b/coordinationnet/model_gnn.py index 06613f0..74871ad 100644 --- a/coordinationnet/model_gnn.py +++ b/coordinationnet/model_gnn.py @@ -16,8 +16,12 @@ import torch -from torch_geometric.data import Data -from torch_geometric.nn import Sequential, GraphConv, HeteroConv, global_mean_pool +from typing import Union + +from torch_geometric.data import Data +from torch_geometric.nn import Sequential, GraphConv, CGConv, HeteroConv, global_mean_pool +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.typing import Adj, OptTensor, PairTensor from .features_coding import NumOxidations, NumGeometries @@ -26,6 +30,17 @@ ## ---------------------------------------------------------------------------- +class IdConv(MessagePassing): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, x: Union[torch.Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor = None) -> torch.Tensor: + + return x + +## ---------------------------------------------------------------------------- + class ModelGraphCoordinationNet(torch.nn.Module): def __init__(self, # Specify model components @@ -48,7 +63,7 @@ def __init__(self, dim_ligand = dim_element + dim_oxidation if model_config['distances']: - dim_ce += dim_distance + dim_ligand += dim_distance if model_config['angles']: dim_ligand += dim_angle @@ -60,38 +75,39 @@ def __init__(self, self.scaler_outputs = TorchStandardScaler(layers[-1]) # RBF encoder - self.rbf_csm = RBFEmbedding(0.0, 1.0, bins=model_config['bins_csm'], edim=dim_csm) - self.rbf_distances = RBFEmbedding(0.0, 1.0, bins=model_config['bins_distance'], edim=dim_distance) - self.rbf_angles = RBFEmbedding(0.0, 1.0, bins=model_config['bins_angle'], edim=dim_angle) + self.rbf_csm = RBFEmbedding(0.0, 1.0, bins=model_config['bins_csm'] , edim=dim_csm) + self.rbf_distances_1 = RBFEmbedding(0.0, 1.0, bins=model_config['bins_distance'], edim=dim_distance) + self.rbf_distances_2 = RBFEmbedding(0.0, 
1.0, bins=model_config['bins_distance'], edim=dim_distance) + self.rbf_angles = RBFEmbedding(0.0, 1.0, bins=model_config['bins_angle'] , edim=dim_angle) # Embeddings - self.embedding_element = ElementEmbedder(edim, from_pretrained=True, freeze=True) - self.embedding_oxidation = torch.nn.Embedding(NumOxidations, 10) - self.embedding_geometry = PaddedEmbedder(NumGeometries, 10) + self.embedding_element = ElementEmbedder(dim_element, from_pretrained=True, freeze=True) + self.embedding_oxidation = torch.nn.Embedding(NumOxidations, dim_oxidation) + self.embedding_geometry = PaddedEmbedder(NumGeometries, dim_geometry) self.activation = torch.nn.ELU(inplace=True) # Core graph network - self.layers = Sequential('x, edge_index, batch', [ + self.layers = Sequential('x, edge_index, edge_attr, batch', [ # Layer 1 ----------------------------------------------------------------------------------- (HeteroConv({ - ('site' , '*', 'site' ): GraphConv((dim_site, dim_site), dim_site , add_self_loops=False), - ('ligand', '*', 'ce' ): GraphConv((dim_ligand, dim_ce), dim_ce , add_self_loops=True ), - ('ce' , '*', 'ligand'): GraphConv((dim_ce, dim_ligand), dim_ligand, add_self_loops=True ), - }), 'x, edge_index -> x'), - # Apply activation - (lambda x: { k : self.activation(v) for k, v in x.items()}, 'x -> x'), + ('site' , '*', 'site' ): IdConv(), + ('ligand', '*', 'ce' ): CGConv((dim_ligand, dim_ce), dim_distance, add_self_loops=True), + ('ce' , '*', 'ligand'): CGConv((dim_ce, dim_ligand), dim_distance, add_self_loops=True), + }), 'x, edge_index, edge_attr -> x'), + # Apply activation, except for site nodes + (lambda x: { k : v if k == 'site' else self.activation(v) for k, v in x.items()}, 'x -> x'), # Layer 2 ----------------------------------------------------------------------------------- (HeteroConv({ - ('site' , '*', 'site' ): GraphConv((dim_site, dim_site), dim_site , add_self_loops=False), - ('ligand', '*', 'ce' ): GraphConv((dim_ligand, dim_ce), dim_ce , add_self_loops=True ), 
- ('ce' , '*', 'ligand'): GraphConv((dim_ce, dim_ligand), dim_ligand, add_self_loops=True ), - }), 'x, edge_index -> x'), - # Apply activation - (lambda x: { k : self.activation(v) for k, v in x.items()}, 'x -> x'), + ('site' , '*', 'site' ): IdConv(), + ('ligand', '*', 'ce' ): CGConv((dim_ligand, dim_ce), dim_distance, add_self_loops=True), + ('ce' , '*', 'ligand'): CGConv((dim_ce, dim_ligand), dim_distance, add_self_loops=True), + }), 'x, edge_index, edge_attr -> x'), + # Apply activation, except for site nodes + (lambda x: { k : v if k == 'site' else self.activation(v) for k, v in x.items()}, 'x -> x'), # Layer 4 ----------------------------------------------------------------------------------- (HeteroConv({ - ('ce', '*', 'site' ): GraphConv((dim_ce, dim_site ), dim_site , add_self_loops=True, bias=False), + ('ce', '*', 'site'): GraphConv((dim_ce, dim_site), dim_site, add_self_loops=True, bias=False), }, aggr='mean'), 'x, edge_index -> x'), # Apply activation (lambda x: { k : self.activation(v) for k, v in x.items()}, 'x -> x'), @@ -126,9 +142,9 @@ def forward(self, x_input): # Add optional features if self.model_config['distances']: - x_ce = torch.cat(( - x_ce, - self.rbf_distances(x_input['ce'].x['distances']), + x_ligand = torch.cat(( + x_ligand, + self.rbf_distances_1(x_input['ligand'].x['distances']), ), dim=1) if self.model_config['angles']: @@ -141,8 +157,12 @@ def forward(self, x_input): x = { 'site': x_site, 'ce': x_ce, 'ligand': x_ligand, } + edge_attr_dict = { + ('ligand', '*', 'ce'): self.rbf_distances_2(x_input['ligand', '*', 'ce'].edge_attr), + ('ce', '*', 'ligand'): self.rbf_distances_2(x_input['ce', '*', 'ligand'].edge_attr), + } # Propagate features through graph network - x = self.layers(x, x_input.edge_index_dict, x_input.batch_dict) + x = self.layers(x, x_input.edge_index_dict, edge_attr_dict, x_input.batch_dict) # Apply final dense layer x = self.dense(x) # Apply inverse transformation diff --git a/coordinationnet/model_gnn_data.py 
b/coordinationnet/model_gnn_data.py index a94e5b3..c135675 100644 --- a/coordinationnet/model_gnn_data.py +++ b/coordinationnet/model_gnn_data.py @@ -39,7 +39,7 @@ def code_csms(csms) -> list[float]: def code_distance(distance : float, l : int) -> torch.Tensor: # Sites `from` and `to` get distance assigned, all ligands # get inf - x = torch.tensor(2*[distance], dtype=torch.float) / 8.0 + x = torch.tensor(l*[distance], dtype=torch.float) / 8.0 return x def code_angles(angles : list[float]) -> torch.Tensor: @@ -48,6 +48,14 @@ def code_angles(angles : list[float]) -> torch.Tensor: x = torch.tensor(angles, dtype=torch.float) / 180 return x +def get_distance(features, site, site_to): + for item in features.distances: + if item['site'] == site and item['site_to'] == site_to: + return item['distance'] + if item['site_to'] == site and item['site'] == site_to: + return item['distance'] + raise RuntimeError('Distance not available') + ## ---------------------------------------------------------------------------- class GraphCoordinationData(GenericDataset): @@ -101,17 +109,19 @@ def __compute_graph_ce_pairs__(cls, features : CoordinationFeatures, data : Hete 'oxidations': torch.tensor([], dtype=torch.long), 'geometries': torch.tensor([], dtype=torch.long), 'csms' : torch.tensor([], dtype=torch.float), - 'distances' : torch.tensor([], dtype=torch.float), } x_ligand = { 'elements' : torch.tensor([], dtype=torch.long), 'oxidations': torch.tensor([], dtype=torch.long), + 'distances' : torch.tensor([], dtype=torch.float), 'angles' : torch.tensor([], dtype=torch.float), } # Edges - e1 = [[], []] - e2 = [[], []] - e3 = [[], []] + edge_index_1 = [[], []] + edge_index_2 = [[], []] + edge_index_3 = [[], []] + # Edge features + edge_attr = [] # Global node index i i1 = 0 i2 = 0 @@ -126,24 +136,27 @@ def __compute_graph_ce_pairs__(cls, features : CoordinationFeatures, data : Hete x_ce['elements' ] = torch.cat((x_ce['elements' ], torch.tensor([ features.sites.elements [site] for site in 
idx_ce ], dtype=torch.long))) x_ce['oxidations'] = torch.cat((x_ce['oxidations'], torch.tensor([ features.sites.oxidations[site] for site in idx_ce ], dtype=torch.long))) x_ce['geometries'] = torch.cat((x_ce['geometries'], torch.tensor([ site_ces[site] for site in idx_ce ], dtype=torch.long ))) - x_ce['distances' ] = torch.cat((x_ce['distances' ], code_distance(nb['distance'], l))) x_ce['csms' ] = torch.cat((x_ce['csms' ], code_csms([ site_csm[site] for site in idx_ce ]))) # Construct ligand features x_ligand['elements' ] = torch.cat((x_ligand['elements' ], torch.tensor([ features.sites.elements [site] for site in idx_ligand ], dtype=torch.long))) x_ligand['oxidations'] = torch.cat((x_ligand['oxidations'], torch.tensor([ features.sites.oxidations[site] for site in idx_ligand ], dtype=torch.long))) - x_ligand['angles' ] = torch.cat((x_ligand['angles' ], code_angles(nb['angles']))) + x_ligand['distances' ] = torch.cat((x_ligand['distances' ], code_distance(nb['distance'], l))) + x_ligand['angles' ] = torch.cat((x_ligand['angles' ], code_angles (nb['angles']))) - for j, _ in enumerate(nb['ligand_indices']): + for j, k in enumerate(nb['ligand_indices']): # From ligand ; To CE - e1[0].append(i2+j); e1[1].append(i1+0) - e1[0].append(i2+j); e1[1].append(i1+1) + edge_index_1[0].append(i2+j); edge_index_1[1].append(i1+0) + edge_index_1[0].append(i2+j); edge_index_1[1].append(i1+1) # From CE ; To ligand - e2[0].append(i1+0); e2[1].append(i2+j) - e2[0].append(i1+1); e2[1].append(i2+j) + edge_index_2[0].append(i1+0); edge_index_2[1].append(i2+j) + edge_index_2[0].append(i1+1); edge_index_2[1].append(i2+j) + # ligand-CE features + edge_attr.append(get_distance(features, idx_ce[0], k)) + edge_attr.append(get_distance(features, idx_ce[1], k)) # Connect CE nodes to site nodes - e3[0].append(i1+0); e3[1].append(nb['site']) - e3[0].append(i1+1); e3[1].append(nb['site_to']) + edge_index_3[0].append(i1+0); edge_index_3[1].append(nb['site']) + edge_index_3[0].append(i1+1); 
edge_index_3[1].append(nb['site_to']) i1 += 2 i2 += len(nb['ligand_indices']) @@ -155,10 +168,13 @@ def __compute_graph_ce_pairs__(cls, features : CoordinationFeatures, data : Hete data['ce' ].x = x_ce data['ligand'].x = x_ligand # Assign edges - data['ligand', '*', 'ce'].edge_index = torch.tensor(e1, dtype=torch.long) - data['ce', '*', 'ligand'].edge_index = torch.tensor(e2, dtype=torch.long) + data['ligand', '*', 'ce'].edge_index = torch.tensor(edge_index_1, dtype=torch.long) + data['ce', '*', 'ligand'].edge_index = torch.tensor(edge_index_2, dtype=torch.long) + # Assign edge features + data['ligand', '*', 'ce'].edge_attr = torch.tensor(edge_attr, dtype=torch.float) + data['ce', '*', 'ligand'].edge_attr = torch.tensor(edge_attr, dtype=torch.float) # Connect CE nodes to site nodes - data['ce', '*', 'site'].edge_index = torch.tensor(e3, dtype=torch.long) + data['ce', '*', 'site'].edge_index = torch.tensor(edge_index_3, dtype=torch.long) @classmethod def __compute_graph__(cls, features : CoordinationFeatures) -> GraphData: