Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: fix AzureSearch vectorstore asyncronous methods #24921

Merged
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 128 additions & 94 deletions libs/community/langchain_community/vectorstores/azuresearch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import base64
import itertools
import json
Expand Down Expand Up @@ -41,7 +42,12 @@

if TYPE_CHECKING:
from azure.search.documents import SearchClient, SearchItemPaged
from azure.search.documents.aio import SearchClient as AsyncSearchClient
from azure.search.documents.aio import (
AsyncSearchItemPaged,
)
from azure.search.documents.aio import (
SearchClient as AsyncSearchClient,
)
from azure.search.documents.indexes.models import (
CorsOptions,
ScoringProfile,
Expand Down Expand Up @@ -360,6 +366,28 @@ def __init__(
self._user_agent = user_agent
self._cors_options = cors_options

def __del__(self):
thedavgar marked this conversation as resolved.
Show resolved Hide resolved
# Close the sync client
self.client.close()

# Close the async client
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# Schedule the coroutine to close the async client
loop.create_task(self.async_client.close())
else:
# If no event loop is running, run the coroutine directly
loop.run_until_complete(self.async_client.close())
except RuntimeError:
# Handle the case where there's no event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.async_client.close())
finally:
loop.close()

@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Support embedding object directly
Expand Down Expand Up @@ -518,21 +546,19 @@ async def aadd_embeddings(
ids.append(key)
# Upload data in batches
if len(data) == MAX_UPLOAD_BATCH_SIZE:
async with self.async_client as async_client:
response = await async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if not all(r.succeeded for r in response):
raise LangChainException(response)
# Reset data
data = []
response = await self.async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if not all(r.succeeded for r in response):
raise LangChainException(response)
# Reset data
data = []

# Considering case where data is an exact multiple of batch-size entries
if len(data) == 0:
return ids

# Upload data to index
async with self.async_client as async_client:
response = await async_client.upload_documents(documents=data)
response = await self.async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if all(r.succeeded for r in response):
return ids
Expand Down Expand Up @@ -566,9 +592,8 @@ async def adelete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> bool:
False otherwise.
"""
if ids:
async with self.async_client as async_client:
res = await async_client.delete_documents([{"id": i} for i in ids])
return len(res) > 0
res = await self.async_client.delete_documents([{"id": i} for i in ids])
return len(res) > 0
else:
return False

Expand Down Expand Up @@ -748,7 +773,7 @@ async def avector_search_with_score(
embedding, "", k, filters=filters, **kwargs
)

return _results_to_documents(results)
return await _aresults_to_documents(results)

def max_marginal_relevance_search_with_score(
self,
Expand Down Expand Up @@ -897,7 +922,7 @@ async def ahybrid_search_with_score(
embedding, query, k, filters=filters, **kwargs
)

return _results_to_documents(results)
return await _aresults_to_documents(results)

def hybrid_search_with_relevance_scores(
self,
Expand Down Expand Up @@ -1050,7 +1075,7 @@ async def _asimple_search(
*,
filters: Optional[str] = None,
**kwargs: Any,
) -> SearchItemPaged[dict]:
) -> AsyncSearchItemPaged[dict]:
"""Perform vector or hybrid search in the Azure search index.

Args:
Expand All @@ -1064,20 +1089,19 @@ async def _asimple_search(
"""
from azure.search.documents.models import VectorizedQuery

async with self.async_client as async_client:
return await async_client.search(
search_text=text_query,
vector_queries=[
VectorizedQuery(
vector=np.array(embedding, dtype=np.float32).tolist(),
k_nearest_neighbors=k,
fields=FIELDS_CONTENT_VECTOR,
)
],
filter=filters,
top=k,
**kwargs,
)
return await self.async_client.search(
search_text=text_query,
vector_queries=[
VectorizedQuery(
vector=np.array(embedding, dtype=np.float32).tolist(),
k_nearest_neighbors=k,
fields=FIELDS_CONTENT_VECTOR,
)
],
filter=filters,
top=k,
**kwargs,
)

def semantic_hybrid_search(
self, query: str, k: int = 4, **kwargs: Any
Expand Down Expand Up @@ -1289,71 +1313,68 @@ async def asemantic_hybrid_search_with_score_and_rerank(
from azure.search.documents.models import VectorizedQuery

vector = await self._aembed_query(query)
async with self.async_client as async_client:
results = await async_client.search(
search_text=query,
vector_queries=[
VectorizedQuery(
vector=np.array(vector, dtype=np.float32).tolist(),
k_nearest_neighbors=k,
fields=FIELDS_CONTENT_VECTOR,
)
],
filter=filters,
query_type="semantic",
semantic_configuration_name=self.semantic_configuration_name,
query_caption="extractive",
query_answer="extractive",
top=k,
**kwargs,
)
# Get Semantic Answers
semantic_answers = (await results.get_answers()) or []
semantic_answers_dict: Dict = {}
for semantic_answer in semantic_answers:
semantic_answers_dict[semantic_answer.key] = {
"text": semantic_answer.text,
"highlights": semantic_answer.highlights,
}
# Convert results to Document objects
docs = [
(
Document(
page_content=result.pop(FIELDS_CONTENT),
metadata={
**(
json.loads(result[FIELDS_METADATA])
if FIELDS_METADATA in result
else {
k: v
for k, v in result.items()
if k != FIELDS_CONTENT_VECTOR
}
results = await self.async_client.search(
search_text=query,
vector_queries=[
VectorizedQuery(
vector=np.array(vector, dtype=np.float32).tolist(),
k_nearest_neighbors=k,
fields=FIELDS_CONTENT_VECTOR,
)
],
filter=filters,
query_type="semantic",
semantic_configuration_name=self.semantic_configuration_name,
query_caption="extractive",
query_answer="extractive",
top=k,
**kwargs,
)
# Get Semantic Answers
semantic_answers = (await results.get_answers()) or []
semantic_answers_dict: Dict = {}
for semantic_answer in semantic_answers:
semantic_answers_dict[semantic_answer.key] = {
"text": semantic_answer.text,
"highlights": semantic_answer.highlights,
}
# Convert results to Document objects
docs = [
(
Document(
page_content=result.pop(FIELDS_CONTENT),
metadata={
**(
json.loads(result[FIELDS_METADATA])
if FIELDS_METADATA in result
else {
k: v
for k, v in result.items()
if k != FIELDS_CONTENT_VECTOR
}
),
**{
"captions": {
"text": result.get("@search.captions", [{}])[0].text,
"highlights": result.get("@search.captions", [{}])[
0
].highlights,
}
if result.get("@search.captions")
else {},
"answers": semantic_answers_dict.get(
result.get(FIELDS_ID, ""),
"",
),
**{
"captions": {
"text": result.get("@search.captions", [{}])[
0
].text,
"highlights": result.get("@search.captions", [{}])[
0
].highlights,
}
if result.get("@search.captions")
else {},
"answers": semantic_answers_dict.get(
result.get(FIELDS_ID, ""),
"",
),
},
},
),
float(result["@search.score"]),
float(result["@search.reranker_score"]),
)
async for result in results
]
return docs
},
),
float(result["@search.score"]),
float(result["@search.reranker_score"]),
)
async for result in results
]
return docs

@classmethod
def from_texts(
Expand Down Expand Up @@ -1631,6 +1652,19 @@ def _results_to_documents(
return docs


async def _aresults_to_documents(
results: AsyncSearchItemPaged[Dict],
) -> List[Tuple[Document, float]]:
docs = [
(
_result_to_document(result),
float(result["@search.score"]),
)
async for result in results
]
return docs


async def _areorder_results_with_maximal_marginal_relevance(
results: SearchItemPaged[Dict],
query_embedding: np.ndarray,
Expand All @@ -1644,7 +1678,7 @@ async def _areorder_results_with_maximal_marginal_relevance(
float(result["@search.score"]),
result[FIELDS_CONTENT_VECTOR],
)
for result in results
async for result in results
]
documents, scores, vectors = map(list, zip(*docs))

Expand Down
Loading