Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions langchain_postgres/v2/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,16 +580,16 @@ async def __query_collection(
For best hybrid search performance, consider creating a TSV column
and adding GIN index.
"""
if not k:
k = (
max(
self.k,
self.hybrid_search_config.primary_top_k,
self.hybrid_search_config.secondary_top_k,
)
if self.hybrid_search_config
else self.k
)
hybrid_search_config = kwargs.get(
"hybrid_search_config", self.hybrid_search_config
)

final_k = k if k is not None else self.k

dense_limit = final_k
if hybrid_search_config:
dense_limit = hybrid_search_config.primary_top_k

operator = self.distance_strategy.operator
search_function = self.distance_strategy.search_function

Expand Down Expand Up @@ -617,9 +617,9 @@ async def __query_collection(
embedding_data_string = ":query_embedding"
where_filters = f"WHERE {safe_filter}" if safe_filter else ""
dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :dense_limit;
"""
param_dict = {"query_embedding": query_embedding, "k": k}
param_dict = {"query_embedding": query_embedding, "dense_limit": dense_limit}
if filter_dict:
param_dict.update(filter_dict)
if self.index_query_options:
Expand All @@ -637,16 +637,13 @@ async def __query_collection(
result_map = result.mappings()
dense_results = result_map.fetchall()

hybrid_search_config = kwargs.get(
"hybrid_search_config", self.hybrid_search_config
)
fts_query = (
hybrid_search_config.fts_query
if hybrid_search_config and hybrid_search_config.fts_query
else kwargs.get("fts_query", "")
)
if hybrid_search_config and fts_query:
hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k
hybrid_search_config.fusion_function_parameters["fetch_top_k"] = final_k
# do the sparse query
lang = (
f"'{hybrid_search_config.tsv_lang}',"
Expand Down