Skip to content

Commit

Permalink
2023/08/30-12:27:07 (Linux sv2111 unknown)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbenner committed Aug 30, 2023
1 parent 734c064 commit 47ce3f9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
20 changes: 13 additions & 7 deletions coordinationnet/model_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@ def __init__(self,
# Transformer options
edim = 200,
# **kwargs contains options for dense layers
layers = [200, 512, 128, 1], **kwargs):
layers = [512, 128, 1], **kwargs):

super().__init__()

print(f'{model_config}')

# Dimension of the graph features
fdim = 2*edim

# The model config determines which components of the model
# are active
self.model_config = model_config
Expand All @@ -47,28 +50,31 @@ def __init__(self,

# Embeddings
self.embedding_element = ElementEmbedder(edim, from_pretrained=True, freeze=False)
self.embedding_ligands = ElementEmbedder(edim, from_pretrained=True, freeze=False)
self.embedding_ces = torch.nn.Embedding(NumGeometries+1, edim)

self.layers = Sequential('x, edge_index, batch', [
(GCNConv(edim, edim), 'x, edge_index -> x'),
(GCNConv(fdim, fdim), 'x, edge_index -> x'),
torch.nn.ELU(inplace=True),
(GCNConv(edim, edim), 'x, edge_index -> x'),
(GCNConv(fdim, fdim), 'x, edge_index -> x'),
torch.nn.ELU(inplace=True),
(global_mean_pool, 'x, batch -> x'),
])

# Final dense layer
self.dense = ModelDense([edim] + layers, **kwargs)
self.dense = ModelDense([fdim] + layers, **kwargs)

print(f'Creating a GNN model with {self.n_parameters:,} parameters')

def forward(self, x_input):

x_elements = self.embedding_element(x_input.x['elements'])
x_elements = self.embedding_element(x_input.x['elements'])
x_oxidations = self.embedding_element(x_input.x['oxidations'])

x = torch.cat((x_elements, x_oxidations), dim=1)

edge_index = x_input.edge_index

x = self.layers(x_elements, edge_index, x_input.batch)
x = self.layers(x, edge_index, x_input.batch)
x = self.dense(x)

return x
Expand Down
2 changes: 1 addition & 1 deletion coordinationnet/model_gnn_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, dataset) -> None:
X = [ item[0] for item in dataset]
y = [ item[1] for item in dataset]

X = self.__compute_graphs__(X, verbose=False)
X = self.__compute_graphs__(X, verbose=True)

super().__init__(X, y)

Expand Down

0 comments on commit 47ce3f9

Please sign in to comment.