Commit 38fe25c

fix: Address additional CodeRabbit review comments
- Fix CUDA device mismatch in LSTMAggregator random permutation
- Clean up imports and use proper typing (DefaultDict) in memory_store
- Fix edge counting for undirected graphs (divide by 2)
- Improve shared memory allocation with guards and memory limits
- Ensure deterministic neighbor ordering with sorted() for reproducibility

These changes improve stability, memory management, and reproducibility of the GraphSAGE implementation.
1 parent: 79c70be

File tree: 3 files changed, +49 −22 lines

core/framework/graph_embedders.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def forward(self, self_feats: torch.Tensor, neighbor_feats: torch.Tensor) -> tor
         """
         # Random permutation of neighbors for LSTM
         batch_size, num_neighbors, _ = neighbor_feats.shape
-        perm = torch.randperm(num_neighbors)
+        perm = torch.randperm(num_neighbors, device=neighbor_feats.device)
         neighbor_feats = neighbor_feats[:, perm, :]
 
         # LSTM aggregation
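
For context, here is a minimal sketch (not code from the repository; the helper name `shuffle_neighbors` is illustrative) of why the `device=` argument matters: indexing a CUDA tensor with a CPU-resident permutation either forces an implicit host-to-device copy or, for ops such as `index_select`, raises a device-mismatch error. Creating the permutation on the input's device keeps the whole operation on one device.

```python
import torch

def shuffle_neighbors(neighbor_feats: torch.Tensor) -> torch.Tensor:
    """Randomly permute the neighbor axis of a (batch, num_neighbors, dim) tensor."""
    _, num_neighbors, _ = neighbor_feats.shape
    # Build the permutation index on the same device as the tensor it indexes.
    perm = torch.randperm(num_neighbors, device=neighbor_feats.device)
    return neighbor_feats[:, perm, :]

# Works identically for CPU and CUDA inputs.
feats = torch.randn(4, 8, 16)
if torch.cuda.is_available():
    feats = feats.cuda()
assert shuffle_neighbors(feats).shape == feats.shape
```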

core/framework/memory_store.py

Lines changed: 45 additions & 19 deletions
@@ -11,10 +11,9 @@
 
 import os
 import numpy as np
-from typing import Dict, List, Set, Tuple, Optional, Any
+from typing import Dict, List, Set, Tuple, Optional, Any, DefaultDict
 from dataclasses import dataclass
 from collections import defaultdict
-import multiprocessing as mp
 from multiprocessing import shared_memory
 import json
 import time
@@ -55,7 +54,7 @@ def __init__(self, max_memory_gb: float = 100.0):
         self.node_embeddings: Optional[np.ndarray] = None  # Shape: (n_nodes, embedding_dim)
 
         # Edge data (adjacency lists for efficiency)
-        self.adjacency: Dict[int, Set[int]] = defaultdict(set)
+        self.adjacency: DefaultDict[int, Set[int]] = defaultdict(set)
         self.edge_types: Dict[Tuple[int, int], str] = {}
 
         # Reverse mappings
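
As a side note, a toy illustration (not repository code) of what the annotation change buys: `DefaultDict[...]` matches the runtime type returned by `defaultdict(set)`, so type checkers accept the missing-key insertion behavior the store relies on.

```python
from collections import defaultdict
from typing import DefaultDict, Set

adjacency: DefaultDict[int, Set[int]] = defaultdict(set)
adjacency[0].add(1)            # key 0 is created on first access
assert adjacency[42] == set()  # missing keys yield an empty set, not KeyError
```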
@@ -114,9 +113,13 @@ def load_from_arangodb(self, db_config: Dict[str, Any]) -> GraphStats:
         load_time = time.time() - start_time
         memory_usage = self._calculate_memory_usage()
 
+        # For undirected graphs, edges are counted twice in adjacency lists
+        # Divide by 2 to get the actual edge count
+        edge_count = sum(len(neighbors) for neighbors in self.adjacency.values()) // 2
+
         self.stats = GraphStats(
             num_nodes=len(self.node_ids),
-            num_edges=sum(len(neighbors) for neighbors in self.adjacency.values()),
+            num_edges=edge_count,
             num_node_types=len(set(self.node_types.values())),
             num_edge_types=len(set(self.edge_types.values())),
             memory_usage_gb=memory_usage,
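
To illustrate the counting logic with a toy graph (illustrative only; the store builds its adjacency from ArangoDB): when an undirected edge is inserted into both endpoints' adjacency sets, summing the set sizes double-counts, and integer division by 2 restores the true edge count.

```python
from collections import defaultdict
from typing import DefaultDict, Set

# Triangle graph: 3 undirected edges, each stored in both directions.
adjacency: DefaultDict[int, Set[int]] = defaultdict(set)
for u, v in [(0, 1), (1, 2), (0, 2)]:
    adjacency[u].add(v)
    adjacency[v].add(u)

total_entries = sum(len(neighbors) for neighbors in adjacency.values())
assert total_entries == 6          # every edge counted twice
assert total_entries // 2 == 3     # actual number of undirected edges
```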
@@ -205,24 +208,46 @@ def create_shared_memory(self, embedding_dim: int = 2048):
         """
         num_nodes = len(self.node_ids)
 
-        # Calculate size needed
-        size = num_nodes * embedding_dim * np.float32().itemsize
+        # Guard against zero-node graphs
+        if num_nodes == 0:
+            print("Warning: No nodes in graph, skipping shared memory creation")
+            return
 
-        # Create shared memory
-        self.shared_memory = shared_memory.SharedMemory(create=True, size=size)
-        self.shared_memory_name = self.shared_memory.name
+        # Clean up any existing shared memory
+        if self.shared_memory is not None:
+            self.cleanup()
 
-        # Create numpy array backed by shared memory
-        self.node_embeddings = np.ndarray(
-            (num_nodes, embedding_dim),
-            dtype=np.float32,
-            buffer=self.shared_memory.buf
-        )
+        # Calculate size needed
+        size = num_nodes * embedding_dim * np.float32().itemsize
+        size_gb = size / (1024 ** 3)
 
-        # Initialize with zeros (will be filled by GraphSAGE)
-        self.node_embeddings[:] = 0
+        # Check memory limit
+        if size_gb > self.max_memory_gb:
+            raise MemoryError(
+                f"Required memory ({size_gb:.2f} GB) exceeds limit ({self.max_memory_gb} GB). "
+                f"Reduce embedding_dim or increase max_memory_gb."
+            )
 
-        print(f"Created shared memory '{self.shared_memory_name}' for embeddings")
+        try:
+            # Create shared memory
+            self.shared_memory = shared_memory.SharedMemory(create=True, size=size)
+            self.shared_memory_name = self.shared_memory.name
+
+            # Create numpy array backed by shared memory
+            self.node_embeddings = np.ndarray(
+                (num_nodes, embedding_dim),
+                dtype=np.float32,
+                buffer=self.shared_memory.buf
+            )
+
+            # Initialize with zeros (will be filled by GraphSAGE)
+            self.node_embeddings[:] = 0
+
+            print(f"Created shared memory '{self.shared_memory_name}' for embeddings ({size_gb:.2f} GB)")
+
+        except Exception as e:
+            self.cleanup()
+            raise RuntimeError(f"Failed to create shared memory: {e}")
 
     def get_neighbors(self, node_index: int, max_neighbors: Optional[int] = None) -> List[int]:
         """
@@ -235,7 +260,8 @@ def get_neighbors(self, node_index: int, max_neighbors: Optional[int] = None) ->
         Returns:
             List of neighbor indices
         """
-        neighbors = list(self.adjacency.get(node_index, set()))
+        # Use sorted() for deterministic ordering
+        neighbors = sorted(list(self.adjacency.get(node_index, set())))
 
         if max_neighbors and len(neighbors) > max_neighbors:
             # Random sampling for scalability

tools/graphsage/utils/neighborhood_sampler.py

Lines changed: 3 additions & 2 deletions
@@ -70,7 +70,7 @@ def _compute_sampling_probabilities(self) -> Optional[Dict[int, np.ndarray]]:
         probs = {}
 
         for node_idx in range(len(self.graph_store.node_ids)):
-            neighbors = list(self.graph_store.adjacency.get(node_idx, set()))
+            neighbors = sorted(list(self.graph_store.adjacency.get(node_idx, set())))
 
             if not neighbors:
                 continue
@@ -107,7 +107,8 @@ def sample_neighbors(self, node: int, num_samples: int) -> List[int]:
         Returns:
             List of sampled neighbor indices
         """
-        neighbors = list(self.graph_store.adjacency.get(node, set()))
+        # Use sorted() for deterministic ordering
+        neighbors = sorted(list(self.graph_store.adjacency.get(node, set())))
 
         if not neighbors:
             return [node]  # Self-loop if no neighbors
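
A minimal sketch of why the `sorted()` call matters for reproducibility (the toy adjacency, seed, and helper below are illustrative, not the sampler's actual implementation): Python sets do not guarantee a stable iteration order across runs or processes, so feeding a seeded RNG an unsorted neighbor list can still yield different samples; sorting first pins the input ordering.

```python
import random

adjacency = {0: {7, 3, 11, 5}}  # toy graph store

def sample_neighbors(node: int, num_samples: int, seed: int = 42) -> list:
    # Sort before sampling so the seeded RNG always sees the same ordering.
    neighbors = sorted(adjacency.get(node, set()))
    if not neighbors:
        return [node]  # self-loop if no neighbors
    if len(neighbors) <= num_samples:
        return neighbors
    return random.Random(seed).sample(neighbors, num_samples)

print(sample_neighbors(0, 2))  # identical output on every run
```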
