Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Use _AstraDBCollectionEnvironment in AstraDB VectorStore (community) #17635

Merged
merged 1 commit into from
Feb 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 46 additions & 163 deletions libs/community/langchain_community/vectorstores/astradb.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import asyncio
import uuid
import warnings
from asyncio import Task
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
Expand All @@ -17,17 +16,21 @@
Tuple,
Type,
TypeVar,
Union,
)

import numpy as np
from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.runnables.utils import gather_with_concurrency
from langchain_core.utils.iter import batch_iterate
from langchain_core.vectorstores import VectorStore

from langchain_community.utilities.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
Expand Down Expand Up @@ -167,28 +170,12 @@ def __init__(
bulk_insert_batch_concurrency: Optional[int] = None,
bulk_insert_overwrite_concurrency: Optional[int] = None,
bulk_delete_concurrency: Optional[int] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
) -> None:
"""
Create an AstraDB vector store object. See class docstring for help.
"""
try:
from astrapy.db import AstraDB as LibAstraDB
from astrapy.db import AstraDBCollection
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent astrapy python package. "
"Please install it with `pip install --upgrade astrapy`."
)

# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDB if passing 'token' and 'api_endpoint'."
)

self.embedding = embedding
self.collection_name = collection_name
self.token = token
Expand All @@ -207,105 +194,35 @@ def __init__(
bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
)
# "vector-related" settings
self._embedding_dimension: Optional[int] = None
self.metric = metric
embedding_dimension: Union[int, Awaitable[int], None] = None
if setup_mode == SetupMode.ASYNC:
embedding_dimension = self._aget_embedding_dimension()
elif setup_mode == SetupMode.SYNC:
embedding_dimension = self._get_embedding_dimension()

self.astra_db = astra_db_client
self.async_astra_db = async_astra_db_client
self.collection = None
self.async_collection = None

if token and api_endpoint:
self.astra_db = LibAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
try:
from astrapy.db import AsyncAstraDB

self.async_astra_db = AsyncAstraDB(
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace,
)
except (ImportError, ModuleNotFoundError):
pass

if self.astra_db is not None:
self.collection = AstraDBCollection(
collection_name=self.collection_name,
astra_db=self.astra_db,
)

self.async_setup_db_task: Optional[Task] = None
if self.async_astra_db is not None:
from astrapy.db import AsyncAstraDBCollection

self.async_collection = AsyncAstraDBCollection(
collection_name=self.collection_name,
astra_db=self.async_astra_db,
)
try:
self.async_setup_db_task = asyncio.create_task(
self._setup_db(pre_delete_collection)
)
except RuntimeError:
pass

if self.async_setup_db_task is None:
if not pre_delete_collection:
self._provision_collection()
else:
self.clear()

def _ensure_astra_db_client(self): # type: ignore[no-untyped-def]
# Guard called at the top of the synchronous methods in this file: raises
# ValueError when self.astra_db is unset (None or otherwise falsy);
# returns None without side effects when a client is present.
if not self.astra_db:
raise ValueError("Missing AstraDB client")

async def _setup_db(self, pre_delete_collection: bool) -> None:
# Asynchronous backend setup. When pre_delete_collection is true the
# collection named self.collection_name is deleted first; in every case the
# collection is then created through _aprovision_collection().
if pre_delete_collection:
# The type-ignore is presumably needed because async_astra_db may be None
# (it is assigned straight from an optional constructor argument).
await self.async_astra_db.delete_collection( # type: ignore[union-attr]
collection_name=self.collection_name,
)
await self._aprovision_collection()

async def _ensure_db_setup(self) -> None:
# Wait for the setup task scheduled in __init__ (asyncio.create_task on
# _setup_db), if there is one. Does nothing when async_setup_db_task is None,
# i.e. when no task could be created.
if self.async_setup_db_task:
await self.async_setup_db_task

def _get_embedding_dimension(self) -> int:
# Return the length of the embedding vectors produced by self.embedding.
# The value is measured lazily, by embedding a fixed sample sentence on the
# first call, and cached in self._embedding_dimension for later calls.
if self._embedding_dimension is None:
self._embedding_dimension = len(
self.embedding.embed_query("This is a sample sentence.")
)
return self._embedding_dimension

def _provision_collection(self) -> None:
"""
Run the API invocation to create the collection on the backend.

Internal-usage method, no object members are set,
other than working on the underlying actual storage.
"""
self.astra_db.create_collection( # type: ignore[union-attr]
dimension=self._get_embedding_dimension(),
collection_name=self.collection_name,
metric=self.metric,
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
metric=metric,
)
self.astra_db = self.astra_env.astra_db
self.async_astra_db = self.astra_env.async_astra_db
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection

async def _aprovision_collection(self) -> None:
"""
Run the API invocation to create the collection on the backend.
def _get_embedding_dimension(self) -> int:
# Return the length of the vector that self.embedding produces for a fixed
# sample sentence. Nothing is cached in this variant: each call runs
# embed_query again.
return len(self.embedding.embed_query(text="This is a sample sentence."))

Internal-usage method, no object members are set,
other than working on the underlying actual storage.
"""
await self.async_astra_db.create_collection( # type: ignore[union-attr]
dimension=self._get_embedding_dimension(),
collection_name=self.collection_name,
metric=self.metric,
)
async def _aget_embedding_dimension(self) -> int:
# Coroutine that returns the embedding dimension: it awaits aembed_query on
# a fixed sample sentence and returns the length of the resulting vector.
# Nothing is cached; each await issues a new embedding call.
return len(await self.embedding.aembed_query(text="This is a sample sentence."))

@property
def embeddings(self) -> Embeddings:
Expand All @@ -326,22 +243,20 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:

def clear(self) -> None:
"""Empty the collection of all its stored entries."""
# NOTE(review): this span comes from a diff rendering whose +/- markers were
# lost, so two implementations appear back to back. Presumably the first pair
# (drop the collection, then re-provision it) is the removed code and the
# second pair (ensure setup, then delete_many with an empty filter) is its
# replacement — confirm against the merged file before relying on either.
self.delete_collection()
self._provision_collection()
self.astra_env.ensure_db_setup()
self.collection.delete_many({})

async def aclear(self) -> None:
"""Empty the collection of all its stored entries."""
# NOTE(review): this span comes from a diff rendering whose +/- markers were
# lost. Presumably the first three statements (await _ensure_db_setup, then a
# fallback to the sync clear() in an executor when no async client exists) are
# the removed code, and the last two are the replacement — confirm against
# the merged file.
# NOTE(review): as written, the executor fallback has no return after it, so
# execution continues into the async statements below.
await self._ensure_db_setup()
if not self.async_astra_db:
await run_in_executor(None, self.clear)
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many({}) # type: ignore[union-attr]

def delete_by_document_id(self, document_id: str) -> bool:
"""
Remove a single document from the store, given its document_id (str).
Return True if a document has indeed been deleted, False if ID not found.
"""
self._ensure_astra_db_client()
self.astra_env.ensure_db_setup()
deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr]
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
Expand All @@ -352,9 +267,7 @@ async def adelete_by_document_id(self, document_id: str) -> bool:
Remove a single document from the store, given its document_id (str).
Return True if a document has indeed been deleted, False if ID not found.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(None, self.delete_by_document_id, document_id)
await self.astra_env.aensure_db_setup()
deletion_response = await self.async_collection.delete_one(document_id)
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
Expand Down Expand Up @@ -439,8 +352,8 @@ def delete_collection(self) -> None:
Stored data is lost and unrecoverable, resources are freed.
Use with caution.
"""
self._ensure_astra_db_client()
self.astra_db.delete_collection( # type: ignore[union-attr]
self.astra_env.ensure_db_setup()
self.astra_db.delete_collection(
collection_name=self.collection_name,
)

Expand All @@ -451,10 +364,8 @@ async def adelete_collection(self) -> None:
Stored data is lost and unrecoverable, resources are freed.
Use with caution.
"""
await self._ensure_db_setup()
if not self.async_astra_db:
await run_in_executor(None, self.delete_collection)
await self.async_astra_db.delete_collection( # type: ignore[union-attr]
await self.astra_env.aensure_db_setup()
await self.async_astra_db.delete_collection(
collection_name=self.collection_name,
)

Expand Down Expand Up @@ -569,7 +480,7 @@ def add_texts(
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
self._ensure_astra_db_client()
self.astra_env.ensure_db_setup()

embedding_vectors = self.embedding.embed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
Expand Down Expand Up @@ -655,22 +566,13 @@ async def aadd_texts(
Returns:
List[str]: List of ids of the added texts.
"""
await self._ensure_db_setup()
if not self.async_collection:
await super().aadd_texts(
texts,
metadatas,
ids=ids,
batch_size=batch_size,
batch_concurrency=batch_concurrency,
overwrite_concurrency=overwrite_concurrency,
)
if kwargs:
warnings.warn(
"Method 'aadd_texts' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
await self.astra_env.aensure_db_setup()

embedding_vectors = await self.embedding.aembed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
Expand Down Expand Up @@ -731,7 +633,7 @@ def similarity_search_with_score_id_by_vector(
Returns:
List of (Document, score, id), the most similar to the query vector.
"""
self._ensure_astra_db_client()
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
#
hits = list(
Expand Down Expand Up @@ -773,15 +675,7 @@ async def asimilarity_search_with_score_id_by_vector(
Returns:
List of (Document, score, id), the most similar to the query vector.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(
None,
self.asimilarity_search_with_score_id_by_vector, # type: ignore[arg-type]
embedding,
k,
filter,
)
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
#
return [
Expand Down Expand Up @@ -1010,7 +904,7 @@ def max_marginal_relevance_search_by_vector(
Returns:
List of Documents selected by maximal marginal relevance.
"""
self._ensure_astra_db_client()
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)

prefetch_hits = list(
Expand Down Expand Up @@ -1051,18 +945,7 @@ async def amax_marginal_relevance_search_by_vector(
Returns:
List of Documents selected by maximal marginal relevance.
"""
await self._ensure_db_setup()
if not self.async_collection:
return await run_in_executor(
None,
self.max_marginal_relevance_search_by_vector,
embedding,
k,
fetch_k,
lambda_mult,
filter,
**kwargs,
)
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)

prefetch_hits = [
Expand Down
Loading