Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 115 additions & 22 deletions plugins/neo4j_graph/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,14 @@ def _perform_backfill(self, collection: str, qdrant_client) -> None:

logger.info(f"Auto-backfill complete: {total_edges} edges from {total_points} points")

# Compute PageRank after populating edges
if total_edges > 0:
try:
pr_count = self.compute_pagerank(collection)
logger.info(f"Auto-backfill: computed PageRank for {pr_count} nodes")
except Exception as e:
logger.warning(f"Auto-backfill: PageRank computation failed: {e}")

except Exception as e:
logger.error(f"Auto-backfill failed: {e}")

Expand Down Expand Up @@ -815,17 +823,24 @@ def get_callers(
repo: Optional[str] = None,
limit: int = 100,
) -> List[Dict[str, Any]]:
"""Find all callers of a symbol using Cypher."""
"""Find all callers of a symbol using Cypher.

Supports both exact matches and class-level queries (includes methods).
For "MyClass", also matches callers of "MyClass.method" etc.
"""
driver = self._get_driver()
db = self._get_database()
collection = self._get_collection(graph_store)
# Pattern for class + methods: exact match OR starts with "Class."
symbol_prefix = f"{symbol}."

try:
with driver.session(database=db) as session:
if repo and repo != "*":
result = session.run("""
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {name: $symbol, collection: $collection})
WHERE r.collection = $collection AND (r.repo = $repo OR callee.repo = $repo)
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE (callee.name = $symbol OR callee.name STARTS WITH $symbol_prefix)
AND r.collection = $collection AND (r.repo = $repo OR callee.repo = $repo)
RETURN caller.name as caller_symbol,
callee.name as callee_symbol,
r.caller_path as caller_path,
Expand All @@ -836,11 +851,12 @@ def get_callers(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"symbol": symbol, "repo": repo, "collection": collection, "limit": limit})
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "repo": repo, "collection": collection, "limit": limit})
else:
result = session.run("""
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {name: $symbol, collection: $collection})
WHERE r.collection = $collection
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE (callee.name = $symbol OR callee.name STARTS WITH $symbol_prefix)
AND r.collection = $collection
RETURN caller.name as caller_symbol,
callee.name as callee_symbol,
r.caller_path as caller_path,
Expand All @@ -851,7 +867,7 @@ def get_callers(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"symbol": symbol, "collection": collection, "limit": limit})
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "collection": collection, "limit": limit})

return [dict(record) for record in result]

Expand All @@ -866,17 +882,24 @@ def get_callees(
repo: Optional[str] = None,
limit: int = 100,
) -> List[Dict[str, Any]]:
"""Find all symbols called by a symbol using Cypher."""
"""Find all symbols called by a symbol using Cypher.

Supports both exact matches and class-level queries (includes methods).
For "MyClass", also matches "MyClass.method" etc.
"""
driver = self._get_driver()
db = self._get_database()
collection = self._get_collection(graph_store)
# Pattern for class + methods: exact match OR starts with "Class."
symbol_prefix = f"{symbol}."

try:
with driver.session(database=db) as session:
if repo and repo != "*":
result = session.run("""
MATCH (caller:Symbol {name: $symbol, collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE r.collection = $collection AND (r.repo = $repo OR caller.repo = $repo)
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE (caller.name = $symbol OR caller.name STARTS WITH $symbol_prefix)
AND r.collection = $collection AND (r.repo = $repo OR caller.repo = $repo)
RETURN caller.name as caller_symbol,
callee.name as callee_symbol,
r.caller_path as caller_path,
Expand All @@ -888,11 +911,12 @@ def get_callees(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"symbol": symbol, "repo": repo, "collection": collection, "limit": limit})
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "repo": repo, "collection": collection, "limit": limit})
else:
result = session.run("""
MATCH (caller:Symbol {name: $symbol, collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE r.collection = $collection
MATCH (caller:Symbol {collection: $collection})-[r:CALLS]->(callee:Symbol {collection: $collection})
WHERE (caller.name = $symbol OR caller.name STARTS WITH $symbol_prefix)
AND r.collection = $collection
RETURN caller.name as caller_symbol,
callee.name as callee_symbol,
r.caller_path as caller_path,
Expand All @@ -904,7 +928,7 @@ def get_callees(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"symbol": symbol, "collection": collection, "limit": limit})
""", {"symbol": symbol, "symbol_prefix": symbol_prefix, "collection": collection, "limit": limit})

return [dict(record) for record in result]

Expand All @@ -919,17 +943,24 @@ def get_importers(
repo: Optional[str] = None,
limit: int = 100,
) -> List[Dict[str, Any]]:
"""Find all files that import a module using Cypher."""
"""Find all files that import a module using Cypher.

Supports both exact matches and submodule queries.
For "mypackage", also matches importers of "mypackage.submodule" etc.
"""
driver = self._get_driver()
db = self._get_database()
collection = self._get_collection(graph_store)
# Pattern for module + submodules: exact match OR starts with "module."
module_prefix = f"{module}."

try:
with driver.session(database=db) as session:
if repo and repo != "*":
result = session.run("""
MATCH (importer:Symbol {collection: $collection})-[r:IMPORTS]->(imported:Symbol {name: $module, collection: $collection})
WHERE r.collection = $collection AND (r.repo = $repo OR imported.repo = $repo)
MATCH (importer:Symbol {collection: $collection})-[r:IMPORTS]->(imported:Symbol {collection: $collection})
WHERE (imported.name = $module OR imported.name STARTS WITH $module_prefix)
AND r.collection = $collection AND (r.repo = $repo OR imported.repo = $repo)
RETURN importer.name as caller_symbol,
imported.name as callee_symbol,
r.caller_path as caller_path,
Expand All @@ -938,11 +969,12 @@ def get_importers(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"module": module, "repo": repo, "collection": collection, "limit": limit})
""", {"module": module, "module_prefix": module_prefix, "repo": repo, "collection": collection, "limit": limit})
else:
result = session.run("""
MATCH (importer:Symbol {collection: $collection})-[r:IMPORTS]->(imported:Symbol {name: $module, collection: $collection})
WHERE r.collection = $collection
MATCH (importer:Symbol {collection: $collection})-[r:IMPORTS]->(imported:Symbol {collection: $collection})
WHERE (imported.name = $module OR imported.name STARTS WITH $module_prefix)
AND r.collection = $collection
RETURN importer.name as caller_symbol,
imported.name as callee_symbol,
r.caller_path as caller_path,
Expand All @@ -951,7 +983,7 @@ def get_importers(
r.edge_id as edge_id,
r.caller_point_id as caller_point_id
LIMIT $limit
""", {"module": module, "collection": collection, "limit": limit})
""", {"module": module, "module_prefix": module_prefix, "collection": collection, "limit": limit})

return [dict(record) for record in result]

Expand Down Expand Up @@ -1080,4 +1112,65 @@ def resolve_import(

except Exception as e:
logger.debug(f"Failed to resolve import {import_name}: {e}")
return None
return None

def compute_pagerank(
    self,
    graph_store: str,
    repo: Optional[str] = None,
    timeout: int = 120,
) -> int:
    """Compute an approximate PageRank score for code symbols.

    Uses a simple in-degree approximation (GDS may not be available).
    Every matched ``Symbol`` node gets a base rank of 0.001; nodes with
    incoming ``CALLS``/``IMPORTS`` edges get ``in_degree / 100.0`` instead.

    Only edges that belong to the same collection (and, when ``repo`` is
    given, the same repo) are counted, so scores are not skewed by edges
    coming from other graphs stored in the same database.

    Args:
        graph_store: Graph store name (collection)
        repo: Optional repository filter; ``None``, ``""`` or ``"*"`` means
            all repos in the collection
        timeout: Transaction timeout in seconds

    Returns:
        Count of nodes updated, or 0 if the computation failed
    """
    driver = self._get_driver()
    db = self._get_database()
    collection = self._get_collection(graph_store)

    try:
        with driver.session(database=db) as session:
            with session.begin_transaction(timeout=timeout) as tx:
                # The WHERE directly after OPTIONAL MATCH filters the optional
                # pattern only: nodes without a qualifying edge are kept with
                # r = null, so count(r) is 0 and they receive the base rank.
                if repo and repo != "*":
                    result = tx.run("""
                        MATCH (n:Symbol {collection: $collection})
                        WHERE n.repo = $repo
                        OPTIONAL MATCH (n)<-[r:CALLS|IMPORTS]-(src:Symbol {collection: $collection})
                        WHERE r.collection = $collection
                          AND (r.repo = $repo OR src.repo = $repo)
                        WITH n, count(r) AS in_degree
                        SET n.pagerank = CASE WHEN in_degree > 0
                            THEN toFloat(in_degree) / 100.0
                            ELSE 0.001 END
                        RETURN count(n) AS cnt
                    """, collection=collection, repo=repo)
                else:
                    result = tx.run("""
                        MATCH (n:Symbol {collection: $collection})
                        OPTIONAL MATCH (n)<-[r:CALLS|IMPORTS]-(src:Symbol {collection: $collection})
                        WHERE r.collection = $collection
                        WITH n, count(r) AS in_degree
                        SET n.pagerank = CASE WHEN in_degree > 0
                            THEN toFloat(in_degree) / 100.0
                            ELSE 0.001 END
                        RETURN count(n) AS cnt
                    """, collection=collection)

                record = result.single()
                cnt = record["cnt"] if record else 0
                # Commit explicitly; leaving the block without committing
                # would discard the SET updates.
                tx.commit()
                logger.info(f"Computed PageRank for {cnt} nodes in {collection}")
                return cnt

    except Exception as e:
        # Best-effort: ranking is an optional enrichment, so report and
        # return 0 rather than propagate.
        logger.error(f"Failed to compute PageRank: {e}")
        return 0
15 changes: 11 additions & 4 deletions plugins/neo4j_graph/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,20 +446,27 @@ def compute_pagerank(

except Exception:
# Fallback: simple in-degree approximation with timeout
# Uses OPTIONAL MATCH to give base rank to ALL nodes, not just those with incoming edges
with session.begin_transaction(timeout=timeout) as tx:
if repo:
result = tx.run("""
MATCH (n)<-[r:CALLS|IMPORTS]-()
MATCH (n:Symbol)
WHERE n.repo = $repo
OPTIONAL MATCH (n)<-[r:CALLS|IMPORTS]-()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the fallback PageRank path, the OPTIONAL MATCH (n)<-[r:CALLS|IMPORTS]-() doesn’t constrain relationships by repo, unlike the GDS rel_query branch (WHERE a.repo = $repo). This can make repo-scoped PageRank inconsistent with the GDS path when repo is provided.

Fix This in Augment

🤖 Was this useful? React with 👍 or 👎

WITH n, count(r) AS in_degree
SET n.pagerank = toFloat(in_degree) / 100.0
SET n.pagerank = CASE WHEN in_degree > 0
THEN toFloat(in_degree) / 100.0
ELSE 0.001 END
RETURN count(n) AS cnt
""", repo=repo)
else:
result = tx.run("""
MATCH (n)<-[r:CALLS|IMPORTS]-()
MATCH (n:Symbol)
OPTIONAL MATCH (n)<-[r:CALLS|IMPORTS]-()
WITH n, count(r) AS in_degree
SET n.pagerank = toFloat(in_degree) / 100.0
SET n.pagerank = CASE WHEN in_degree > 0
THEN toFloat(in_degree) / 100.0
ELSE 0.001 END
RETURN count(n) AS cnt
""")
cnt = result.single()["cnt"]
Expand Down
1 change: 1 addition & 0 deletions scripts/benchmarks/coir/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ async def search_one(qid: str, query_text: str) -> tuple:
rerank_top_n=rerank_top_n if self.rerank_enabled else None,
rerank_return_m=top_k if self.rerank_enabled else None,
mode=self.mode,
output_format="json", # Ensure dict results, not TOON strings
)
# Extract scores
doc_scores = {}
Expand Down
Loading
Loading