diff --git a/coordinationnet/model_gnn.py b/coordinationnet/model_gnn.py index af3957a..9254e61 100644 --- a/coordinationnet/model_gnn.py +++ b/coordinationnet/model_gnn.py @@ -33,12 +33,15 @@ def __init__(self, # Transformer options edim = 200, # **kwargs contains options for dense layers - layers = [200, 512, 128, 1], **kwargs): + layers = [512, 128, 1], **kwargs): super().__init__() print(f'{model_config}') + # Dimension of the graph features + fdim = 2*edim + # The model config determines which components of the model # are active self.model_config = model_config @@ -47,28 +50,31 @@ def __init__(self, # Embeddings self.embedding_element = ElementEmbedder(edim, from_pretrained=True, freeze=False) - self.embedding_ligands = ElementEmbedder(edim, from_pretrained=True, freeze=False) self.embedding_ces = torch.nn.Embedding(NumGeometries+1, edim) self.layers = Sequential('x, edge_index, batch', [ - (GCNConv(edim, edim), 'x, edge_index -> x'), + (GCNConv(fdim, fdim), 'x, edge_index -> x'), torch.nn.ELU(inplace=True), - (GCNConv(edim, edim), 'x, edge_index -> x'), + (GCNConv(fdim, fdim), 'x, edge_index -> x'), torch.nn.ELU(inplace=True), (global_mean_pool, 'x, batch -> x'), ]) # Final dense layer - self.dense = ModelDense([edim] + layers, **kwargs) + self.dense = ModelDense([fdim] + layers, **kwargs) print(f'Creating a GNN model with {self.n_parameters:,} parameters') def forward(self, x_input): - x_elements = self.embedding_element(x_input.x['elements']) + x_elements = self.embedding_element(x_input.x['elements']) + x_oxidations = self.embedding_element(x_input.x['oxidations']) + + x = torch.cat((x_elements, x_oxidations), dim=1) + edge_index = x_input.edge_index - x = self.layers(x_elements, edge_index, x_input.batch) + x = self.layers(x, edge_index, x_input.batch) x = self.dense(x) return x diff --git a/coordinationnet/model_gnn_data.py b/coordinationnet/model_gnn_data.py index 5e8e24f..a016617 100644 --- a/coordinationnet/model_gnn_data.py +++ b/coordinationnet/model_gnn_data.py @@ -38,7 +38,7 @@ def __init__(self, dataset) -> None: X = [ item[0] for item in dataset] y = [ item[1] for item in dataset] - X = self.__compute_graphs__(X, verbose=False) + X = self.__compute_graphs__(X, verbose=True) super().__init__(X, y)