Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.template
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
GOOGLE_API_KEY=
GOOGLE_API_KEY=your_api_key_here

# Set to true to use Vertex AI; set to false to use Gemini
GEMINI_USE_VERTEX=true
Expand Down
24 changes: 9 additions & 15 deletions backend/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
from retrieval import get_retriever
from rrf import reciprocal_rank_fusion

# LLM (Gemini) client setup
try:
Expand Down Expand Up @@ -426,23 +427,16 @@ async def execute_search(state: AgentState) -> Dict[str, Any]:


def fuse_results(state: AgentState) -> AgentState:
    """
    Fuse keyword-search and vector-search hits with Reciprocal Rank Fusion.

    Reads ``ks_results`` and ``vector_results`` from ``state`` and hands both
    lists to ``reciprocal_rank_fusion``, which deduplicates and ranks them.
    The earlier weighted-sum fusion (0.6 * similarity + 0.4 * _score) is
    dropped: its output was overwritten by the RRF result and never used.

    Args:
        state: The agent state; ``ks_results`` and ``vector_results`` are
            treated as empty lists when missing or ``None``.

    Returns:
        A copy of ``state`` extended with ``all_results`` (the fused list,
        capped at 60 entries) and ``final_results`` (its first 15 entries).
    """
    print("--- Node: Result Fusion (RRF) ---")
    # `or []` also covers a key that is present but explicitly set to None.
    ks_results = state.get("ks_results") or []
    vector_results = state.get("vector_results") or []

    # Both lists go to RRF, which handles deduplication and ranking of
    # documents that appear in either or both lists.
    all_sorted = reciprocal_rank_fusion([vector_results, ks_results], k=60, top_k=60)

    print(f"RRF fusion: KS={len(ks_results)}, Vector={len(vector_results)} → Combined={len(all_sorted)} unique results")

    page_size = 15
    return {**state, "all_results": all_sorted, "final_results": all_sorted[:page_size]}

Expand Down
2 changes: 1 addition & 1 deletion backend/ks_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def global_fuzzy_keyword_search(keywords: Iterable[str], top_k: int = 20) -> Lis
all_configs = json.load(fh)
out: List[dict] = []
seen = set()
for kw in (keywords or []):
for kw in keywords or []:
if not kw:
continue
results = search_across_all_fields(kw, all_configs, threshold=0.8)
Expand Down
82 changes: 82 additions & 0 deletions backend/rrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import logging
from typing import List, Dict, Any, Set

logger = logging.getLogger("rrf")

def extract_doc_id(result: Dict[str, Any]) -> str:
    """
    Return the document identifier of a search result as a string.

    The ``"id"`` key is consulted first and ``"_id"`` second, so results that
    carry their identifier under either key resolve to the same value.

    Args:
        result: A search result dictionary.

    Returns:
        The first identifier that is neither ``None`` nor an empty string,
        converted with ``str()``; ``""`` when no usable identifier exists.
    """
    for key in ("id", "_id"):
        value = result.get(key)
        # Test for None/"" explicitly: a plain truthiness check would discard
        # legitimate falsy identifiers such as the integer 0.
        if value is not None and value != "":
            return str(value)
    return ""

def reciprocal_rank_fusion(
    ranked_lists: List[List[Dict[str, Any]]],
    k: int = 60,
    top_k: int = 15,
) -> List[Dict[str, Any]]:
    """
    Combine several ranked result lists into one using Reciprocal Rank Fusion.

    Formula: RRF_score(d) = sum(1 / (k + rank_i(d)))
    where ``rank_i(d)`` is the 1-based position of document ``d`` in list ``i``.

    Documents are matched across lists by ``extract_doc_id``. A document
    without an identifier is matched by its ``title_guess`` when it has one;
    otherwise it is kept as a distinct entry, so unrelated id-less documents
    are never merged with each other.

    Args:
        ranked_lists: A list of lists, each holding document dicts ordered
            best-first. ``None`` lists and non-dict entries are ignored
            (a skipped entry still occupies its rank position).
        k: The smoothing constant added to every rank (default: 60).
        top_k: The number of fused results to return (applied as a slice).

    Returns:
        Shallow copies of the fused documents, ordered by RRF score
        descending; ties keep first-seen order. For a document present in
        several lists, the copy comes from the first list that contains it.
        Each dict gains ``rrf_score`` and ``final_score`` (same value). The
        input dicts themselves are not modified.
    """
    # Accumulated score per fusion key, in first-seen order.
    rrf_scores: Dict[str, float] = {}
    # Fusion key -> copy of the first document dict seen for that key.
    doc_map: Dict[str, Dict[str, Any]] = {}

    for list_idx, ranked_list in enumerate(ranked_lists):
        # A document contributes at most once per list (its best rank), so a
        # backend returning the same hit twice cannot inflate its score.
        seen_in_list: Set[str] = set()

        for idx, doc in enumerate(ranked_list or []):
            if not isinstance(doc, dict):
                continue

            doc_id = extract_doc_id(doc)
            if doc_id:
                key = f"id:{doc_id}"
            else:
                title = doc.get("title_guess")
                if title:
                    # Weak fallback: same title is treated as same document.
                    key = f"title:{title}"
                else:
                    # Nothing to match on: the position makes the key unique,
                    # instead of collapsing every such document into one.
                    key = f"pos:{list_idx}:{idx}"

            if key in seen_in_list:
                continue
            seen_in_list.add(key)

            rank = idx + 1  # RRF uses 1-based ranks
            rrf_scores[key] = rrf_scores.get(key, 0.0) + (1.0 / (k + rank))

            if key not in doc_map:
                # Shallow copy so the score fields below do not leak into the
                # caller's original result dicts.
                doc_map[key] = dict(doc)

    # sorted() is stable, so equal scores keep their first-seen order.
    ordered_keys = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)

    fused_results: List[Dict[str, Any]] = []
    for key in ordered_keys[:top_k]:
        doc = doc_map[key]
        score = rrf_scores[key]
        doc["rrf_score"] = score
        # Maintain backward compatibility with agents.py expectations.
        doc["final_score"] = score
        fused_results.append(doc)

    logger.debug("Combined %d lists into %d results.", len(ranked_lists), len(fused_results))
    return fused_results
75 changes: 75 additions & 0 deletions backend/tests/test_rrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
from rrf import reciprocal_rank_fusion, extract_doc_id

def test_extract_doc_id():
    """extract_doc_id reads 'id' first, then '_id', and yields '' when neither exists."""
    expectations = [
        ({"id": "123"}, "123"),
        ({"_id": "456"}, "456"),
        ({"id": "123", "_id": "456"}, "123"),  # 'id' takes precedence
        ({}, ""),
    ]
    for payload, wanted in expectations:
        assert extract_doc_id(payload) == wanted

def test_rrf_single_list():
    """A lone ranked list keeps its order and scores each entry 1 / (k + rank)."""
    ranking = [{"id": "A"}, {"id": "B"}, {"id": "C"}]
    merged = reciprocal_rank_fusion([ranking], k=60, top_k=10)

    assert [entry["id"] for entry in merged] == ["A", "B", "C"]

    # With k=60 the 1-based ranks 1, 2, 3 give 1/61, 1/62, 1/63.
    assert [entry["rrf_score"] for entry in merged] == [1 / 61, 1 / 62, 1 / 63]

def test_rrf_two_lists_same_order():
    """Identical orderings under 'id' and '_id' fuse into one entry per document."""
    first = [{"id": "A"}, {"id": "B"}]
    second = [{"_id": "A"}, {"_id": "B"}]  # this list identifies docs via '_id'
    merged = reciprocal_rank_fusion([first, second], k=60, top_k=10)

    # The returned dicts come from the first list, so the 'id' key is present.
    assert [entry["id"] for entry in merged] == ["A", "B"]

    # "A" holds rank 1 in both lists: 1/61 + 1/61.
    top = merged[0]
    assert top["rrf_score"] == (1 / 61) + (1 / 61)

def test_rrf_boosts_overlap():
    """A document ranked second in both lists beats documents ranked first in only one."""
    left = [{"id": "B"}, {"id": "A"}, {"id": "X"}]
    right = [{"id": "C"}, {"id": "A"}, {"id": "Y"}]

    merged = reciprocal_rank_fusion([left, right], k=60, top_k=10)

    score_by_id = {}
    for entry in merged:
        score_by_id[entry["id"]] = entry["rrf_score"]

    # A: 1/62 + 1/62 ~= 0.032258
    # B: 1/61        ~= 0.016393
    # C: 1/61        ~= 0.016393
    assert merged[0]["id"] == "A"
    for rival in ("B", "C"):
        assert score_by_id["A"] > score_by_id[rival]

def test_rrf_empty_lists():
    """Empty input yields an empty result; an empty list beside a populated one is harmless."""
    for nothing in ([], [[], []]):
        assert reciprocal_rank_fusion(nothing, k=60) == []

    # One populated list fused with one empty list.
    populated = [{"id": "A"}]
    merged = reciprocal_rank_fusion([populated, []], k=60)
    assert len(merged) == 1
    assert merged[0]["id"] == "A"

def test_rrf_top_k_truncates():
    """Only the top_k best-ranked documents are returned."""
    candidates = [{"id": str(number)} for number in range(100)]
    merged = reciprocal_rank_fusion([candidates], k=60, top_k=5)

    assert len(merged) == 5
    # The survivors are ids "0" through "4", so the last one is "4".
    assert merged[-1]["id"] == "4"

def test_rrf_id_fallback():
    """Documents carrying neither 'id' nor '_id' are still fused without crashing."""
    # Only 'title_guess' is available here; the two titles differ, so both
    # documents must survive as separate results.
    untagged = [{"title_guess": "Unique Title"}, {"title_guess": "Another Title"}]
    merged = reciprocal_rank_fusion([untagged])

    assert len(merged) == 2
    assert merged[0].get("rrf_score") is not None