Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 54 additions & 20 deletions plugins/neo4j_graph/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
_BACKFILL_LAST_CHECK: dict = {}
_BACKFILL_CHECK_INTERVAL = 30 # Seconds between retry attempts when Qdrant is empty

# Optional PageRank refresh throttling (PageRank is expensive; default is OFF)
_PAGERANK_LAST_RUN: dict[str, float] = {}
_PAGERANK_LOCK = threading.Lock()

# Environment variable to disable auto-backfill (enabled by default)
AUTO_BACKFILL_DISABLED = os.environ.get("NEO4J_AUTO_BACKFILL_DISABLE", "").strip().lower() in {"1", "true", "yes", "on"}

Expand Down Expand Up @@ -409,7 +413,11 @@ async def run_query_async(
except ImportError:
# Fallback: run sync query in thread pool
import asyncio
loop = asyncio.get_event_loop()
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# Defensive fallback for edge cases (should not happen inside an async fn)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
self._run_query_sync,
Expand Down Expand Up @@ -445,6 +453,13 @@ def ensure_graph_store(self, base_collection: str) -> Optional[str]:
db = self._get_database()

if db in self._initialized_databases:
# Indexes already exist, but auto-backfill is per-collection.
# Ensure we still perform the backfill check for new collections.
if base_collection:
try:
self._check_auto_backfill(base_collection)
except Exception as e:
logger.debug(f"Auto-backfill check failed: {e}")
return base_collection or db

try:
Expand Down Expand Up @@ -960,12 +975,25 @@ def upsert_edges(
logger.error(f"Failed to upsert Neo4j INHERITS_FROM edges batch: {e}")

# Compute simple degree-based importance scores for new nodes
# This avoids requiring Neo4j GDS (Graph Data Science) library
if total > 0:
try:
self._compute_simple_pagerank(collection)
except Exception as e:
logger.warning(f"Failed to compute pagerank: {e}")
# NOTE: Full PageRank/in-degree sweeps are expensive and can dominate ingest time.
# Keep this OFF by default; run `compute_pagerank()` explicitly after indexing,
# or enable `NEO4J_PAGERANK_ON_UPSERT=1` for periodic refreshes.
if total > 0 and os.environ.get("NEO4J_PAGERANK_ON_UPSERT", "").strip().lower() in {"1", "true", "yes", "on"}:
import time as _time
min_interval = float(os.environ.get("NEO4J_PAGERANK_ON_UPSERT_MIN_INTERVAL", "300") or 300)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

min_interval = float(os.environ.get(...)) will raise ValueError if NEO4J_PAGERANK_ON_UPSERT_MIN_INTERVAL is set to a non-numeric string, which would abort edge upserts; consider guarding the parse and falling back to a default.

Fix This in Augment

🤖 Was this useful? React with 👍 or 👎

now = _time.time()
should_run = True
with _PAGERANK_LOCK:
last = _PAGERANK_LAST_RUN.get(collection, 0.0)
if (now - last) < min_interval:
should_run = False
else:
_PAGERANK_LAST_RUN[collection] = now
if should_run:
try:
self._compute_simple_pagerank(collection)
except Exception as e:
logger.warning(f"Failed to compute pagerank: {e}")

return total

Expand Down Expand Up @@ -1210,13 +1238,15 @@ def get_callers(
collection = self._get_collection(graph_store)
# Pattern for class + methods: exact match OR starts with "Class."
symbol_prefix = f"{symbol}."
# Suffix match for symbols stored with full module paths (e.g., "pkg.mod.MyClass")
symbol_suffix = f".{symbol}"

try:
with driver.session(database=db) as session:
if repo and repo != "*":
result = session.run("""
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE (callee.name = $symbol OR callee.name STARTS WITH $symbol_prefix)
WHERE (callee.name = $symbol OR callee.name STARTS WITH $symbol_prefix OR callee.name ENDS WITH $symbol_suffix)
AND r.collection = $collection AND (r.repo = $repo OR callee.repo = $repo)
RETURN caller.name as caller_symbol,
callee.name as callee_symbol,
Expand All @@ -1228,11 +1258,11 @@ def get_callers(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "repo": repo, "collection": collection, "limit": limit})
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "symbol_suffix": symbol_suffix, "repo": repo, "collection": collection, "limit": limit})
else:
result = session.run("""
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE (callee.name = $symbol OR callee.name STARTS WITH $symbol_prefix)
WHERE (callee.name = $symbol OR callee.name STARTS WITH $symbol_prefix OR callee.name ENDS WITH $symbol_suffix)
AND r.collection = $collection
RETURN caller.name as caller_symbol,
callee.name as callee_symbol,
Expand All @@ -1244,7 +1274,7 @@ def get_callers(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "collection": collection, "limit": limit})
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "symbol_suffix": symbol_suffix, "collection": collection, "limit": limit})

return [dict(record) for record in result]

Expand All @@ -1269,13 +1299,15 @@ def get_callees(
collection = self._get_collection(graph_store)
# Pattern for class + methods: exact match OR starts with "Class."
symbol_prefix = f"{symbol}."
# Suffix match for symbols stored with full module paths (e.g., "pkg.mod.func")
symbol_suffix = f".{symbol}"

try:
with driver.session(database=db) as session:
if repo and repo != "*":
result = session.run("""
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE (caller.name = $symbol OR caller.name STARTS WITH $symbol_prefix)
WHERE (caller.name = $symbol OR caller.name STARTS WITH $symbol_prefix OR caller.name ENDS WITH $symbol_suffix)
AND r.collection = $collection AND (r.repo = $repo OR caller.repo = $repo)
RETURN caller.name as caller_symbol,
callee.name as callee_symbol,
Expand All @@ -1288,11 +1320,11 @@ def get_callees(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "repo": repo, "collection": collection, "limit": limit})
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "symbol_suffix": symbol_suffix, "repo": repo, "collection": collection, "limit": limit})
else:
result = session.run("""
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE (caller.name = $symbol OR caller.name STARTS WITH $symbol_prefix)
WHERE (caller.name = $symbol OR caller.name STARTS WITH $symbol_prefix OR caller.name ENDS WITH $symbol_suffix)
AND r.collection = $collection
RETURN caller.name as caller_symbol,
callee.name as callee_symbol,
Expand All @@ -1305,7 +1337,7 @@ def get_callees(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "collection": collection, "limit": limit})
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "symbol_suffix": symbol_suffix, "collection": collection, "limit": limit})

return [dict(record) for record in result]

Expand All @@ -1330,13 +1362,15 @@ def get_importers(
collection = self._get_collection(graph_store)
# Pattern for module + submodules: exact match OR starts with "module."
module_prefix = f"{module}."
# Suffix match for fully-qualified imports (e.g., "pkg.sub.mod")
module_suffix = f".{module}"

try:
with driver.session(database=db) as session:
if repo and repo != "*":
result = session.run("""
MATCH (importer:Symbol {collection: $collection})-[r:IMPORTS]->(imported:Symbol {collection: $collection})
WHERE (imported.name = $module OR imported.name STARTS WITH $module_prefix)
WHERE (imported.name = $module OR imported.name STARTS WITH $module_prefix OR imported.name ENDS WITH $module_suffix)
AND r.collection = $collection AND (r.repo = $repo OR imported.repo = $repo)
RETURN importer.name as caller_symbol,
imported.name as callee_symbol,
Expand All @@ -1346,11 +1380,11 @@ def get_importers(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"module": module, "module_prefix": module_prefix, "repo": repo, "collection": collection, "limit": limit})
""", {"module": module, "module_prefix": module_prefix, "module_suffix": module_suffix, "repo": repo, "collection": collection, "limit": limit})
else:
result = session.run("""
MATCH (importer:Symbol {collection: $collection})-[r:IMPORTS]->(imported:Symbol {collection: $collection})
WHERE (imported.name = $module OR imported.name STARTS WITH $module_prefix)
WHERE (imported.name = $module OR imported.name STARTS WITH $module_prefix OR imported.name ENDS WITH $module_suffix)
AND r.collection = $collection
RETURN importer.name as caller_symbol,
imported.name as callee_symbol,
Expand All @@ -1360,7 +1394,7 @@ def get_importers(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"module": module, "module_prefix": module_prefix, "collection": collection, "limit": limit})
""", {"module": module, "module_prefix": module_prefix, "module_suffix": module_suffix, "collection": collection, "limit": limit})

return [dict(record) for record in result]

Expand Down Expand Up @@ -1733,4 +1767,4 @@ def compute_pagerank(

except Exception as e:
logger.error(f"Failed to compute PageRank: {e}")
return 0
return 0
Loading
Loading