feat: Added Hybrid Search Config and Tests [1/N] #211
Merged: vishwarajanand merged 24 commits into langchain-ai:main from vishwarajanand:hybrid_search_1 on Jun 3, 2025.
Changes from 2 commits (of 24 total, all by vishwarajanand):

22088a1 feat: Added Hybrid Search Config and Tests [1/N]
30942ff feat: create hybrid search capable vector store table [2/N]
e641575 feat: adds hybrid search for async VS interface [3/N]
2a0bf0d feat: adds hybrid search for sync VS interface [4/N]
0562678 Merge branch 'main' into hybrid_search_1
70ee300 fix: tests
5234648 fix: pr comments
73d4400 fix: lint
57ceb2c fix: lint
678e7b1 Merge branch 'hybrid_search_1' into hybrid_search_2
7feb7a0 Merge branch 'hybrid_search_2' into hybrid_search_3
ef349a3 pr comment: add disclaimer on slow query on config docstring
ceabf10 pr comment: add disclaimer in engine table create
9611164 Merge branch 'hybrid_search_1' into hybrid_search_2
8a39e61 feat: address pr comments
e5bd215 Merge branch 'hybrid_search_2' into hybrid_search_3
6854ee0 fix: tsv column name in tests
5bf1a4b fix: add if exists in drop to avoid failures
4153c2d Merge branch 'hybrid_search_3' into hybrid_search_4
e092c82 fix: tests
08a4ff6 feat: adds hybrid search for sync VS interface [4/N]
076f0cb feat: adds hybrid search for async VS interface [3/N]
620e3e5 feat: create hybrid search capable vector store table [2/N]
0d223fd chore: fix lint
New file: hybrid_search_config.py (+143 lines)

```python
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Sequence

from sqlalchemy import RowMapping


def weighted_sum_ranking(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    primary_results_weight: float = 0.5,
    secondary_results_weight: float = 0.5,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """
    Ranks documents using a weighted sum of scores from two sources.

    Args:
        primary_search_results: A sequence of rows from the primary search;
            the first value of each row is the document id and the last
            value is its score/distance.
        secondary_search_results: A sequence of rows from the secondary
            search, in the same shape as the primary results.
        primary_results_weight: The weight for the primary source's scores.
            Defaults to 0.5.
        secondary_results_weight: The weight for the secondary source's scores.
            Defaults to 0.5.
        fetch_top_k: The number of documents to fetch after merging the results.
            Defaults to 4.

    Returns:
        A list of row dictionaries, sorted by the fused "distance" score in
        descending order.
    """

    # stores computed metric with provided distance metric and weights
    weighted_scores: dict[str, dict[str, Any]] = {}

    # Process results from primary source
    for row in primary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is doc_id
        distance = float(values[-1])  # type: ignore # last value is distance
        row_values = dict(row)
        row_values["distance"] = primary_results_weight * distance
        weighted_scores[doc_id] = row_values

    # Process results from secondary source,
    # adding to existing scores or creating new ones
    for row in secondary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is doc_id
        distance = float(values[-1])  # type: ignore # last value is distance
        primary_score = (
            weighted_scores[doc_id]["distance"] if doc_id in weighted_scores else 0.0
        )
        row_values = dict(row)
        row_values["distance"] = distance * secondary_results_weight + primary_score
        weighted_scores[doc_id] = row_values

    # Sort the results by weighted score in descending order
    ranked_results = sorted(
        weighted_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]


def reciprocal_rank_fusion(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    rrf_k: float = 60,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """
    Ranks documents using Reciprocal Rank Fusion (RRF) of scores from two sources.

    Args:
        primary_search_results: A sequence of rows from the primary search;
            the first value of each row is the document id and the last
            value is its score/distance.
        secondary_search_results: A sequence of rows from the secondary
            search, in the same shape as the primary results.
        rrf_k: The RRF smoothing constant k. Defaults to 60.
        fetch_top_k: The number of documents to fetch after merging the results.
            Defaults to 4.

    Returns:
        A list of row dictionaries, sorted by the fused RRF "distance" score
        in descending order.
    """
    rrf_scores: dict[str, dict[str, Any]] = {}

    # Process results from primary source
    for rank, row in enumerate(
        sorted(primary_search_results, key=lambda item: item["distance"], reverse=True)
    ):
        values = list(row.values())
        doc_id = str(values[0])
        row_values = dict(row)
        primary_score = rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
        primary_score += 1.0 / (rank + rrf_k)
        row_values["distance"] = primary_score
        rrf_scores[doc_id] = row_values

    # Process results from secondary source
    for rank, row in enumerate(
        sorted(
            secondary_search_results, key=lambda item: item["distance"], reverse=True
        )
    ):
        values = list(row.values())
        doc_id = str(values[0])
        row_values = dict(row)
        secondary_score = (
            rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
        )
        secondary_score += 1.0 / (rank + rrf_k)
        row_values["distance"] = secondary_score
        rrf_scores[doc_id] = row_values

    # Sort the results by RRF score in descending order
    ranked_results = sorted(
        rrf_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]


@dataclass
class HybridSearchConfig(ABC):
    """Google AlloyDB Vector Store Hybrid Search Config."""

    tsv_column: Optional[str] = ""
    tsv_lang: Optional[str] = "pg_catalog.english"
    fts_query: Optional[str] = ""
    fusion_function: Callable[
        [Sequence[RowMapping], Sequence[RowMapping], Any], Sequence[Any]
    ] = weighted_sum_ranking
    fusion_function_parameters: dict[str, Any] = field(default_factory=dict)
    primary_top_k: int = 4
    secondary_top_k: int = 4
    index_name: str = "langchain_tsv_index"
    index_type: str = "GIN"
```
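For reference, the weighted-sum fusion above can be exercised standalone. The sketch below condenses its core loop, with plain dicts standing in for SQLAlchemy RowMapping rows (an illustrative assumption; the real function receives query results):

```python
from typing import Any, Sequence


def fuse_weighted_sum(
    primary: Sequence[dict],
    secondary: Sequence[dict],
    primary_weight: float = 0.5,
    secondary_weight: float = 0.5,
    fetch_top_k: int = 4,
) -> list[dict[str, Any]]:
    """Condensed stand-in for weighted_sum_ranking: the first value of each
    row is the doc id, the last value is its score; the fused score is
    stored under the "distance" key."""
    scores: dict[str, dict[str, Any]] = {}
    for weight, rows in ((primary_weight, primary), (secondary_weight, secondary)):
        for row in rows:
            vals = list(row.values())
            doc_id = str(vals[0])
            # On the first pass no prior entry exists, so prev is 0.0,
            # matching the real function's two separate loops.
            prev = scores[doc_id]["distance"] if doc_id in scores else 0.0
            merged = dict(row)
            merged["distance"] = weight * float(vals[-1]) + prev
            scores[doc_id] = merged
    ranked = sorted(scores.values(), key=lambda r: r["distance"], reverse=True)
    return ranked[:fetch_top_k]


primary = [{"id": "common", "distance": 0.8}, {"id": "p_only", "distance": 0.7}]
secondary = [{"id": "common", "distance": 0.9}, {"id": "s_only", "distance": 0.6}]
ranked = fuse_weighted_sum(primary, secondary)
print([(r["id"], round(r["distance"], 2)) for r in ranked])
# → [('common', 0.85), ('p_only', 0.35), ('s_only', 0.3)]
```

With the default 0.5/0.5 weights, a document found by both searches (here "common") accumulates contributions from each list, which is why it outranks documents found by only one.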
New test file (+220 lines)

```python
import pytest

from langchain_postgres.v2.hybrid_search_config import (
    reciprocal_rank_fusion,
    weighted_sum_ranking,
)


# Helper to create mock input items that mimic RowMapping for the fusion functions
def get_row(doc_id: str, score: float, content: str = "content") -> dict:
    """
    Simulates a RowMapping-like dictionary.

    The fusion functions expect to extract doc_id as the first value and
    the initial score/distance as the last value when casting values from
    RowMapping. They then operate on dictionaries, using the 'distance'
    key for the fused score.
    """
    # Python dicts maintain insertion order (Python 3.7+).
    # This structure ensures list(row.values())[0] is doc_id and
    # list(row.values())[-1] is score.
    return {"id_val": doc_id, "content_field": content, "distance": score}


class TestWeightedSumRanking:
    def test_empty_inputs(self):
        results = weighted_sum_ranking([], [])
        assert results == []

    def test_primary_only(self):
        primary = [get_row("p1", 0.8), get_row("p2", 0.6)]
        # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3
        results = weighted_sum_ranking(
            primary, [], primary_results_weight=0.5, secondary_results_weight=0.5
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "p1"
        assert results[0]["distance"] == pytest.approx(0.4)
        assert results[1]["id_val"] == "p2"
        assert results[1]["distance"] == pytest.approx(0.3)

    def test_secondary_only(self):
        secondary = [get_row("s1", 0.9), get_row("s2", 0.7)]
        # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35
        results = weighted_sum_ranking(
            [], secondary, primary_results_weight=0.5, secondary_results_weight=0.5
        )
        assert len(results) == 2
        assert results[0]["id_val"] == "s1"
        assert results[0]["distance"] == pytest.approx(0.45)
        assert results[1]["id_val"] == "s2"
        assert results[1]["distance"] == pytest.approx(0.35)

    def test_mixed_results_default_weights(self):
        primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
        secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
        # Weights are 0.5, 0.5
        # common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85
        # p_only_score = (0.7 * 0.5) = 0.35
        # s_only_score = (0.6 * 0.5) = 0.30
        # Order: common (0.85), p_only (0.35), s_only (0.30)

        results = weighted_sum_ranking(primary, secondary)
        assert len(results) == 3
        assert results[0]["id_val"] == "common"
        assert results[0]["distance"] == pytest.approx(0.85)
        assert results[1]["id_val"] == "p_only"
        assert results[1]["distance"] == pytest.approx(0.35)
        assert results[2]["id_val"] == "s_only"
        assert results[2]["distance"] == pytest.approx(0.30)

    def test_mixed_results_custom_weights(self):
        primary = [get_row("d1", 1.0)]  # p_w=0.2 -> 0.2
        secondary = [get_row("d1", 0.5)]  # s_w=0.8 -> 0.4
        # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6

        results = weighted_sum_ranking(
            primary, secondary, primary_results_weight=0.2, secondary_results_weight=0.8
        )
        assert len(results) == 1
        assert results[0]["id_val"] == "d1"
        assert results[0]["distance"] == pytest.approx(0.6)

    def test_fetch_top_k(self):
        primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
        # Scores: 1.0, 0.9, 0.8, 0.7, 0.6
        # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3
        secondary = []
        results = weighted_sum_ranking(primary, secondary, fetch_top_k=2)
        assert len(results) == 2
        assert results[0]["id_val"] == "p0"
        assert results[0]["distance"] == pytest.approx(0.5)
        assert results[1]["id_val"] == "p1"
        assert results[1]["distance"] == pytest.approx(0.45)


class TestReciprocalRankFusion:
    def test_empty_inputs(self):
        results = reciprocal_rank_fusion([], [])
        assert results == []

    def test_primary_only(self):
        primary = [
            get_row("p1", 0.8),
            get_row("p2", 0.6),
        ]  # p1 rank 0, p2 rank 1
        rrf_k = 60
        # p1_score = 1 / (0 + 60)
        # p2_score = 1 / (1 + 60)
        results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k)
        assert len(results) == 2
        assert results[0]["id_val"] == "p1"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "p2"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_secondary_only(self):
        secondary = [
            get_row("s1", 0.9),
            get_row("s2", 0.7),
        ]  # s1 rank 0, s2 rank 1
        rrf_k = 60
        results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k)
        assert len(results) == 2
        assert results[0]["id_val"] == "s1"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "s2"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_mixed_results_default_k(self):
        primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
        secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
        rrf_k = 60
        # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k
        # p_only_score = (1/(1+k))_prim = 1/(k+1)
        # s_only_score = (1/(1+k))_sec = 1/(k+1)
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k)
        assert len(results) == 3
        assert results[0]["id_val"] == "common"
        assert results[0]["distance"] == pytest.approx(2.0 / rrf_k)
        # Check the next two elements; their order might vary due to the tie in score
        next_ids = {results[1]["id_val"], results[2]["id_val"]}
        next_scores = {results[1]["distance"], results[2]["distance"]}
        assert next_ids == {"p_only", "s_only"}
        for score in next_scores:
            assert score == pytest.approx(1.0 / (1 + rrf_k))

    def test_fetch_top_k_rrf(self):
        primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
        secondary = []
        rrf_k = 1
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k, fetch_top_k=2)
        assert len(results) == 2
        assert results[0]["id_val"] == "p0"
        assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
        assert results[1]["id_val"] == "p1"
        assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))

    def test_rrf_content_preservation(self):
        primary = [get_row("doc1", 0.9, content="Primary Content")]
        secondary = [get_row("doc1", 0.8, content="Secondary Content")]
        # RRF processes primary then secondary. If a doc is in both,
        # the content from the secondary list will overwrite primary's.
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=60)
        assert len(results) == 1
        assert results[0]["id_val"] == "doc1"
        assert results[0]["content_field"] == "Secondary Content"

        # If only in primary
        results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60)
        assert results_prim_only[0]["content_field"] == "Primary Content"

    def test_reordering_from_inputs_rrf(self):
        """
        Tests that the RRF fused ranking can differ from both the primary
        and the secondary input rankings.

        Primary order: A, B, C
        Secondary order: C, B, A
        Fused order: (A, C) tied, then B
        """
        primary = [
            get_row("docA", 0.9),
            get_row("docB", 0.8),
            get_row("docC", 0.1),
        ]
        secondary = [
            get_row("docC", 0.9),
            get_row("docB", 0.5),
            get_row("docA", 0.2),
        ]
        rrf_k = 1.0  # Using 1.0 for k to simplify rank score calculation
        # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3
        # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1
        # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3
        results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k)
        assert len(results) == 3
        assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"}
        assert results[0]["distance"] == pytest.approx(4.0 / 3.0)
        assert results[1]["distance"] == pytest.approx(4.0 / 3.0)
        assert results[2]["id_val"] == "docB"
        assert results[2]["distance"] == pytest.approx(1.0)

    def test_reordering_from_inputs_weighted_sum(self):
        """
        Tests that the weighted-sum fused ranking can differ from both the
        primary and the secondary input rankings.

        Primary order: A (0.9), B (0.7)
        Secondary order: B (0.8), A (0.2)
        Fusion (0.5/0.5 weights):
            docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55
            docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75
        Expected fused order: docB (0.75), docA (0.55), which matches
        neither input ordering.
        """
        primary = [get_row("docA", 0.9), get_row("docB", 0.7)]
        secondary = [get_row("docB", 0.8), get_row("docA", 0.2)]

        results = weighted_sum_ranking(primary, secondary)
        assert len(results) == 2
        assert results[0]["id_val"] == "docB"
        assert results[0]["distance"] == pytest.approx(0.75)
        assert results[1]["id_val"] == "docA"
        assert results[1]["distance"] == pytest.approx(0.55)
```
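The RRF arithmetic exercised by these tests is easy to check by hand. The sketch below uses a hypothetical helper (not part of the package) that computes the rank-based contributions directly from document orderings, since only ranks, not raw scores, enter the RRF formula 1 / (rank + k):

```python
def rrf_by_order(
    primary_order: list[str], secondary_order: list[str], k: float = 60.0
) -> dict[str, float]:
    # Each appearance of a document at zero-based rank r contributes 1 / (r + k).
    scores: dict[str, float] = {}
    for order in (primary_order, secondary_order):
        for rank, doc_id in enumerate(order):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (rank + k)
    return scores


# Matches test_reordering_from_inputs_rrf: primary A, B, C; secondary C, B, A; k = 1
scores = rrf_by_order(["docA", "docB", "docC"], ["docC", "docB", "docA"], k=1.0)
# docA = 1/1 + 1/3 = 4/3; docB = 1/2 + 1/2 = 1; docC = 1/3 + 1/1 = 4/3
```

Note how docA and docC tie at 4/3 even though their raw scores (0.9 vs 0.1 in primary) differ wildly: RRF deliberately discards score magnitudes and fuses on rank alone.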
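Finally, a sketch of how a vector store might dispatch the configured fusion function through `fusion_function_parameters`. Both the config and the fusion function are re-declared here in trimmed form so the example is self-contained (in real use they would be imported from langchain_postgres.v2.hybrid_search_config), and the invocation pattern shown is an assumption about how the store wires them together:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Sequence


def weighted_sum_ranking(
    primary: Sequence[dict],
    secondary: Sequence[dict],
    primary_results_weight: float = 0.5,
    secondary_results_weight: float = 0.5,
    fetch_top_k: int = 4,
) -> list[dict[str, Any]]:
    # Trimmed stand-in for the fusion function in the PR above.
    scores: dict[str, dict[str, Any]] = {}
    for weight, rows in (
        (primary_results_weight, primary),
        (secondary_results_weight, secondary),
    ):
        for row in rows:
            vals = list(row.values())
            doc_id = str(vals[0])
            prev = scores[doc_id]["distance"] if doc_id in scores else 0.0
            merged = dict(row)
            merged["distance"] = weight * float(vals[-1]) + prev
            scores[doc_id] = merged
    ranked = sorted(scores.values(), key=lambda r: r["distance"], reverse=True)
    return ranked[:fetch_top_k]


@dataclass
class HybridSearchConfig:
    # Trimmed: only the fusion-related fields of the real config.
    fusion_function: Callable[..., Sequence[Any]] = weighted_sum_ranking
    fusion_function_parameters: dict[str, Any] = field(default_factory=dict)


config = HybridSearchConfig(
    fusion_function_parameters={
        "primary_results_weight": 0.2,
        "secondary_results_weight": 0.8,
    },
)
# A store would merge its dense and full-text result sets like this:
fused = config.fusion_function(
    [{"id": "d1", "distance": 1.0}],  # dense (vector) results
    [{"id": "d1", "distance": 0.5}],  # full-text results
    **config.fusion_function_parameters,
)
print(fused[0]["distance"])  # 1.0*0.2 + 0.5*0.8 = 0.6
```

Keeping the fusion function and its keyword arguments as separate config fields lets callers swap in reciprocal_rank_fusion (or any compatible callable) without the store needing to know its signature.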