Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Fix Bug in Azure Search Vectorstore search asynchronously #24081

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 65 additions & 74 deletions libs/community/langchain_community/vectorstores/azuresearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,21 @@ def __init__(
user_agent=user_agent,
cors_options=cors_options,
)
self.async_client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
semantic_configurations=semantic_configurations,
scoring_profiles=scoring_profiles,
default_scoring_profile=default_scoring_profile,
default_fields=default_fields,
user_agent=user_agent,
cors_options=cors_options,
async_=True,
)
self.search_type = search_type
self.semantic_configuration_name = semantic_configuration_name
self.fields = fields if fields else default_fields
Expand All @@ -338,23 +353,6 @@ def __init__(
self._user_agent = user_agent
self._cors_options = cors_options

def _async_client(self) -> AsyncSearchClient:
return _get_search_client(
self._azure_search_endpoint,
self._azure_search_key,
self._index_name,
semantic_configuration_name=self._semantic_configuration_name,
fields=self._fields,
vector_search=self._vector_search,
semantic_configurations=self._semantic_configurations,
scoring_profiles=self._scoring_profiles,
default_scoring_profile=self._default_scoring_profile,
default_fields=self._default_fields,
user_agent=self._user_agent,
cors_options=self._cors_options,
async_=True,
)

@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Support embedding object directly
Expand Down Expand Up @@ -513,7 +511,7 @@ async def aadd_embeddings(
ids.append(key)
# Upload data in batches
if len(data) == MAX_UPLOAD_BATCH_SIZE:
async with self._async_client() as async_client:
async with self.async_client as async_client:
response = await async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if not all(r.succeeded for r in response):
Expand All @@ -526,7 +524,7 @@ async def aadd_embeddings(
return ids

# Upload data to index
async with self._async_client() as async_client:
async with self.async_client as async_client:
response = await async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if all(r.succeeded for r in response):
Expand Down Expand Up @@ -561,7 +559,7 @@ async def adelete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> bool:
False otherwise.
"""
if ids:
async with self._async_client() as async_client:
async with self.async_client as async_client:
res = await async_client.delete_documents([{"id": i} for i in ids])
return len(res) > 0
else:
Expand Down Expand Up @@ -739,11 +737,11 @@ async def avector_search_with_score(
to the query and score for each
"""
embedding = await self._aembed_query(query)
docs, scores, _ = await self._asimple_search(
results = await self._asimple_search(
embedding, "", k, filters=filters, **kwargs
)

return list(zip(docs, scores))
return _results_to_documents(results)

def max_marginal_relevance_search_with_score(
self,
Expand Down Expand Up @@ -807,14 +805,12 @@ async def amax_marginal_relevance_search_with_score(
to the query and score for each
"""
embedding = await self._aembed_query(query)
docs, scores, vectors = await self._asimple_search(
results = await self._asimple_search(
embedding, "", fetch_k, filters=filters, **kwargs
)

return await self._areorder_results_with_maximal_marginal_relevance(
docs,
scores,
vectors,
return await _areorder_results_with_maximal_marginal_relevance(
results,
query_embedding=np.array(embedding),
lambda_mult=lambda_mult,
k=k,
Expand Down Expand Up @@ -890,11 +886,11 @@ async def ahybrid_search_with_score(
"""

embedding = await self._aembed_query(query)
docs, scores, _ = await self._asimple_search(
results = await self._asimple_search(
embedding, query, k, filters=filters, **kwargs
)

return list(zip(docs, scores))
return _results_to_documents(results)

def hybrid_search_with_relevance_scores(
self,
Expand Down Expand Up @@ -992,14 +988,12 @@ async def ahybrid_max_marginal_relevance_search_with_score(
"""

embedding = await self._aembed_query(query)
docs, scores, vectors = await self._asimple_search(
results = await self._asimple_search(
embedding, query, fetch_k, filters=filters, **kwargs
)

return await self._areorder_results_with_maximal_marginal_relevance(
docs,
scores,
vectors,
return await _areorder_results_with_maximal_marginal_relevance(
results,
query_embedding=np.array(embedding),
lambda_mult=lambda_mult,
k=k,
Expand Down Expand Up @@ -1049,7 +1043,7 @@ async def _asimple_search(
*,
filters: Optional[str] = None,
**kwargs: Any,
) -> Tuple[List[Document], List[float], List[List[float]]]:
) -> SearchItemPaged[dict]:
"""Perform vector or hybrid search in the Azure search index.

Args:
Expand All @@ -1063,8 +1057,8 @@ async def _asimple_search(
"""
from azure.search.documents.models import VectorizedQuery

async with self._async_client() as async_client:
results = await async_client.search(
async with self.async_client as async_client:
return await async_client.search(
search_text=text_query,
vector_queries=[
VectorizedQuery(
Expand All @@ -1077,18 +1071,6 @@ async def _asimple_search(
top=k,
**kwargs,
)
docs = [
(
_result_to_document(result),
float(result["@search.score"]),
result[FIELDS_CONTENT_VECTOR],
)
async for result in results
]
if not docs:
raise ValueError(f"No {docs=}")
documents, scores, vectors = map(list, zip(*docs))
return documents, scores, vectors

def semantic_hybrid_search(
self, query: str, k: int = 4, **kwargs: Any
Expand Down Expand Up @@ -1300,7 +1282,7 @@ async def asemantic_hybrid_search_with_score_and_rerank(
from azure.search.documents.models import VectorizedQuery

vector = await self._aembed_query(query)
async with self._async_client() as async_client:
async with self.async_client as async_client:
results = await async_client.search(
search_text=query,
vector_queries=[
Expand Down Expand Up @@ -1475,30 +1457,6 @@ def from_embeddings(
azure_search.add_embeddings(text_embeddings, metadatas, **kwargs)
return azure_search

async def _areorder_results_with_maximal_marginal_relevance(
self,
documents: List[Document],
scores: List[float],
vectors: List[List[float]],
query_embedding: np.ndarray,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[Tuple[Document, float]]:
# Get the new order of results.
new_ordering = maximal_marginal_relevance(
query_embedding, vectors, k=k, lambda_mult=lambda_mult
)

# Reorder the values and return.
ret: List[Tuple[Document, float]] = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
ret.append((documents[x], scores[x])) # type: ignore

return ret

def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore
"""Return AzureSearchVectorStoreRetriever initialized from this VectorStore.

Expand Down Expand Up @@ -1666,6 +1624,39 @@ def _results_to_documents(
return docs


async def _areorder_results_with_maximal_marginal_relevance(
    results: SearchItemPaged[Dict],
    query_embedding: np.ndarray,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[Tuple[Document, float]]:
    """Re-rank async search hits with Maximal Marginal Relevance (MMR).

    Args:
        results: Paged hits produced by ``await async_client.search(...)``.
            NOTE(review): these come from the *async* SearchClient, so they
            must be consumed with ``async for`` — a plain ``for`` fails at
            runtime on the async paged iterator.
        query_embedding: Embedding of the original query, as a numpy array.
        lambda_mult: Relevance/diversity trade-off in [0, 1] passed through
            to ``maximal_marginal_relevance``.
        k: Maximum number of documents to return.

    Returns:
        Up to ``k`` ``(Document, score)`` pairs in MMR order; empty list
        when the search returned no hits.
    """
    # Materialize (document, score, content-vector) triples from the async
    # iterator; the vector is needed as MMR candidate input.
    docs = [
        (
            _result_to_document(result),
            float(result["@search.score"]),
            result[FIELDS_CONTENT_VECTOR],
        )
        async for result in results
    ]
    if not docs:
        # Guard: zip(*[]) would make the unpack below raise ValueError.
        return []
    documents, scores, vectors = map(list, zip(*docs))

    # Get the new order of results.
    new_ordering = maximal_marginal_relevance(
        query_embedding, vectors, k=k, lambda_mult=lambda_mult
    )

    # Reorder the values and return.
    ret: List[Tuple[Document, float]] = []
    for x in new_ordering:
        # maximal_marginal_relevance returns -1 when it runs out of
        # candidates before reaching k; stop collecting there.
        if x == -1:
            break
        ret.append((documents[x], scores[x]))  # type: ignore

    return ret


def _reorder_results_with_maximal_marginal_relevance(
results: SearchItemPaged[Dict],
query_embedding: np.ndarray,
Expand Down
Loading