1111
1212import os
1313import numpy as np
14- from typing import Dict , List , Set , Tuple , Optional , Any
14+ from typing import Dict , List , Set , Tuple , Optional , Any , DefaultDict
1515from dataclasses import dataclass
1616from collections import defaultdict
17- import multiprocessing as mp
1817from multiprocessing import shared_memory
1918import json
2019import time
@@ -55,7 +54,7 @@ def __init__(self, max_memory_gb: float = 100.0):
5554 self .node_embeddings : Optional [np .ndarray ] = None # Shape: (n_nodes, embedding_dim)
5655
5756 # Edge data (adjacency lists for efficiency)
58- self .adjacency : Dict [int , Set [int ]] = defaultdict (set )
57+ self .adjacency : DefaultDict [int , Set [int ]] = defaultdict (set )
5958 self .edge_types : Dict [Tuple [int , int ], str ] = {}
6059
6160 # Reverse mappings
@@ -114,9 +113,13 @@ def load_from_arangodb(self, db_config: Dict[str, Any]) -> GraphStats:
114113 load_time = time .time () - start_time
115114 memory_usage = self ._calculate_memory_usage ()
116115
116+ # For undirected graphs, edges are counted twice in adjacency lists
117+ # Divide by 2 to get the actual edge count
118+ edge_count = sum (len (neighbors ) for neighbors in self .adjacency .values ()) // 2
119+
117120 self .stats = GraphStats (
118121 num_nodes = len (self .node_ids ),
119- num_edges = sum ( len ( neighbors ) for neighbors in self . adjacency . values ()) ,
122+ num_edges = edge_count ,
120123 num_node_types = len (set (self .node_types .values ())),
121124 num_edge_types = len (set (self .edge_types .values ())),
122125 memory_usage_gb = memory_usage ,
@@ -205,24 +208,46 @@ def create_shared_memory(self, embedding_dim: int = 2048):
205208 """
206209 num_nodes = len (self .node_ids )
207210
208- # Calculate size needed
209- size = num_nodes * embedding_dim * np .float32 ().itemsize
211+ # Guard against zero-node graphs
212+ if num_nodes == 0 :
213+ print ("Warning: No nodes in graph, skipping shared memory creation" )
214+ return
210215
211- # Create shared memory
212- self .shared_memory = shared_memory . SharedMemory ( create = True , size = size )
213- self . shared_memory_name = self .shared_memory . name
216+ # Clean up any existing shared memory
217+ if self .shared_memory is not None :
218+ self .cleanup ()
214219
215- # Create numpy array backed by shared memory
216- self .node_embeddings = np .ndarray (
217- (num_nodes , embedding_dim ),
218- dtype = np .float32 ,
219- buffer = self .shared_memory .buf
220- )
220+ # Calculate size needed
221+ size = num_nodes * embedding_dim * np .float32 ().itemsize
222+ size_gb = size / (1024 ** 3 )
221223
222- # Initialize with zeros (will be filled by GraphSAGE)
223- self .node_embeddings [:] = 0
224+ # Check memory limit
225+ if size_gb > self .max_memory_gb :
226+ raise MemoryError (
227+ f"Required memory ({ size_gb :.2f} GB) exceeds limit ({ self .max_memory_gb } GB). "
228+ f"Reduce embedding_dim or increase max_memory_gb."
229+ )
224230
225- print (f"Created shared memory '{ self .shared_memory_name } ' for embeddings" )
231+ try :
232+ # Create shared memory
233+ self .shared_memory = shared_memory .SharedMemory (create = True , size = size )
234+ self .shared_memory_name = self .shared_memory .name
235+
236+ # Create numpy array backed by shared memory
237+ self .node_embeddings = np .ndarray (
238+ (num_nodes , embedding_dim ),
239+ dtype = np .float32 ,
240+ buffer = self .shared_memory .buf
241+ )
242+
243+ # Initialize with zeros (will be filled by GraphSAGE)
244+ self .node_embeddings [:] = 0
245+
246+ print (f"Created shared memory '{ self .shared_memory_name } ' for embeddings ({ size_gb :.2f} GB)" )
247+
248+ except Exception as e :
249+ self .cleanup ()
250+ raise RuntimeError (f"Failed to create shared memory: { e } " )
226251
227252 def get_neighbors (self , node_index : int , max_neighbors : Optional [int ] = None ) -> List [int ]:
228253 """
@@ -235,7 +260,8 @@ def get_neighbors(self, node_index: int, max_neighbors: Optional[int] = None) ->
235260 Returns:
236261 List of neighbor indices
237262 """
238- neighbors = list (self .adjacency .get (node_index , set ()))
263+ # Use sorted() for deterministic ordering
264+ neighbors = sorted (list (self .adjacency .get (node_index , set ())))
239265
240266 if max_neighbors and len (neighbors ) > max_neighbors :
241267 # Random sampling for scalability
0 commit comments