From 734c0648a636e46e4ba0dd15f45db5a35339d061 Mon Sep 17 00:00:00 2001 From: Philipp Benner Date: Wed, 30 Aug 2023 12:16:06 +0200 Subject: [PATCH] 2023/08/30-12:16:06 (Linux sv2111 unknown) --- coordinationnet/model_gnn_data.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/coordinationnet/model_gnn_data.py b/coordinationnet/model_gnn_data.py index b631326..5e8e24f 100644 --- a/coordinationnet/model_gnn_data.py +++ b/coordinationnet/model_gnn_data.py @@ -44,19 +44,20 @@ def __init__(self, dataset) -> None: @classmethod def __compute_graph__(cls, features : CoordinationFeatures) -> GraphData: + # Initialize graph with isolated nodes for each site x = { - 'elements' : torch.empty((0,), dtype=torch.long), - 'oxidations': torch.empty((0,), dtype=torch.long), - #'ces' : torch.empty((0,), dtype=torch.long), - #'csm' : torch.empty((0,), dtype=torch.float), - #'angles' : torch.empty((0,), dtype=torch.float), - #'distance' : torch.empty((0,), dtype=torch.float), + 'elements' : torch.tensor(features.sites.elements , dtype=torch.long), + 'oxidations': torch.tensor(features.sites.oxidations, dtype=torch.long), } - # Global edge index - e = [[], []] + # Global edge index, initialize with self-connections for + # isolated nodes + e = [[ i for i, _ in enumerate(features.sites.elements) ], + [ i for i, _ in enumerate(features.sites.elements) ]] + # Global node index + i = len(features.sites.elements) # Some materials do not have CE pairs if len(features.ce_neighbors) == 0: - return GraphData(x=x, edge_index=torch.empty((2,0), dtype=torch.long), num_nodes=0) + return GraphData(x=x, edge_index=torch.tensor(e, dtype=torch.long), num_nodes=i) # Get CE symbols and CSMs site_ces = len(features.sites.elements)*[NumGeometries] site_csm = len(features.sites.elements)*[0.0] @@ -70,8 +71,6 @@ def __compute_graph__(cls, features : CoordinationFeatures) -> GraphData: # Consider only the first CE symbol site_ces[j] = ce['ce_symbols'][0] site_csm[j] = ce['csms'][0] - # Global node index - i = 0 # Construct CE graphs for nb in features.ce_neighbors: l = len(nb['ligand_indices'])