Skip to content

Commit

Permalink
fix: score_threshold handling in vector search methods (#8356)
Browse files (browse the repository at this point in the history)
  • Loading branch information
laipz8200 authored Sep 13, 2024
1 parent a45ac6a commit 08c4864
Show file tree
Hide file tree
Showing 14 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def delete_by_metadata_field(self, key: str, value: str) -> None:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
Expand Down Expand Up @@ -267,7 +267,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/chroma/chroma_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def text_exists(self, id: str) -> bool:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)

ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc

docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/milvus/milvus_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/myscale/myscale_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
metadata = {}

metadata["score"] = hit["_score"]
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
Expand Down
4 changes: 2 additions & 2 deletions api/core/rag/datasource/vdb/oracle/oraclevector.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
[numpy.array(query_vector)],
)
docs = []
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
Expand All @@ -212,7 +212,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
# filtering by score_threshold is not implemented here yet; it may be added later
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
metadata = record.meta
score = 1 - dis
metadata["score"] = score
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
(json.dumps(query_vector),),
)
docs = []
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
Expand Down
4 changes: 2 additions & 2 deletions api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,13 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", 0.0),
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/relyt/relyt_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
# Organize results.
docs = []
for document, score in results:
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score > score_threshold:
docs.append(document)
return docs
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/tencent/tencent_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def delete_by_metadata_field(self, key: str, value: str) -> None:

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
filter = kwargs.get("filter")
distance = 1 - score_threshold

Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc

docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
doc.metadata["score"] = score
Expand Down

0 comments on commit 08c4864

Please sign in to comment.