From 8b022f309f21f5a1cda82246ba68324949d01f49 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Sat, 4 May 2024 16:02:05 +0530 Subject: [PATCH] refactor: Configure collection names Qdrant --- src/vanna/qdrant/qdrant.py | 128 ++++++++++++++++++++----------------- 1 file changed, 71 insertions(+), 57 deletions(-) diff --git a/src/vanna/qdrant/qdrant.py b/src/vanna/qdrant/qdrant.py index 3730af0d..48c0023e 100644 --- a/src/vanna/qdrant/qdrant.py +++ b/src/vanna/qdrant/qdrant.py @@ -7,48 +7,51 @@ from ..base import VannaBase from ..utils import deterministic_uuid -DOCUMENTATION_COLLECTION_NAME = "documentation" -DDL_COLLECTION_NAME = "ddl" -SQL_COLLECTION_NAME = "sql" SCROLL_SIZE = 1000 -ID_SUFFIXES = { - DDL_COLLECTION_NAME: "ddl", - DOCUMENTATION_COLLECTION_NAME: "doc", - SQL_COLLECTION_NAME: "sql", -} - class Qdrant_VectorStore(VannaBase): - """Vectorstore implementation using Qdrant - https://qdrant.tech/""" + """ + Vectorstore implementation using Qdrant - https://qdrant.tech/ + + Args: + - config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`. + - client: A `qdrant_client.QdrantClient` instance. Overrides other config options. + - location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter. + - url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`. + - prefer_grpc: If `true` - use gRPC interface whenever possible in custom methods. + - https: If `true` - use HTTPS(SSL) protocol. Default: `None` + - api_key: API key for authentication in Qdrant Cloud. Default: `None` + - timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC. + - path: Persistence path for QdrantLocal. Default: `None`. + - prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`. + - n_results: Number of results to return from similarity search. 
Defaults to 10. + - fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`. + Defaults to `"BAAI/bge-small-en-v1.5"`. + - collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method. + - distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`. + - documentation_collection_name: Name of the collection to store documentation. Defaults to `"documentation"`. + - ddl_collection_name: Name of the collection to store DDL. Defaults to `"ddl"`. + - sql_collection_name: Name of the collection to store SQL. Defaults to `"sql"`. + + Raises: + TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance + """ + + documentation_collection_name = "documentation" + ddl_collection_name = "ddl" + sql_collection_name = "sql" + + id_suffixes = { + ddl_collection_name: "ddl", + documentation_collection_name: "doc", + sql_collection_name: "sql", + } def __init__( self, config={}, ): - """ - Vectorstore implementation using Qdrant - https://qdrant.tech/ - - Args: - - config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`. - - client: A `qdrant_client.QdrantClient` instance. Overrides other config options. - - location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter. - - url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`. - - prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods. - - https: If `true` - use HTTPS(SSL) protocol. Default: `None` - - api_key: API key for authentication in Qdrant Cloud. Default: `None` - - timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC. - - path: Persistence path for QdrantLocal. Default: `None`. 
- - prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`. - - n_results: Number of results to return from similarity search. Defaults to 10. - - fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`. - Defaults to `"BAAI/bge-small-en-v1.5"`. - - collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method. - - distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`. - - Raises: - TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance - """ VannaBase.__init__(self, config=config) client = config.get("client") @@ -75,6 +78,15 @@ def __init__( self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5") self.collection_params = config.get("collection_params", {}) self.distance_metric = config.get("distance_metric", models.Distance.COSINE) + self.documentation_collection_name = config.get( + "documentation_collection_name", self.documentation_collection_name + ) + self.ddl_collection_name = config.get( + "ddl_collection_name", self.ddl_collection_name + ) + self.sql_collection_name = config.get( + "sql_collection_name", self.sql_collection_name + ) self._setup_collections() @@ -83,7 +95,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: id = deterministic_uuid(question_answer) self._client.upsert( - SQL_COLLECTION_NAME, + self.sql_collection_name, points=[ models.PointStruct( id=id, @@ -96,12 +108,12 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: ], ) - return self._format_point_id(id, SQL_COLLECTION_NAME) + return self._format_point_id(id, self.sql_collection_name) def add_ddl(self, ddl: str, **kwargs) -> str: id = deterministic_uuid(ddl) self._client.upsert( - DDL_COLLECTION_NAME, + 
self.ddl_collection_name, points=[ models.PointStruct( id=id, @@ -112,13 +124,13 @@ def add_ddl(self, ddl: str, **kwargs) -> str: ) ], ) - return self._format_point_id(id, DDL_COLLECTION_NAME) + return self._format_point_id(id, self.ddl_collection_name) def add_documentation(self, documentation: str, **kwargs) -> str: id = deterministic_uuid(documentation) self._client.upsert( - DOCUMENTATION_COLLECTION_NAME, + self.documentation_collection_name, points=[ models.PointStruct( id=id, @@ -130,16 +142,17 @@ def add_documentation(self, documentation: str, **kwargs) -> str: ], ) - return self._format_point_id(id, DOCUMENTATION_COLLECTION_NAME) + return self._format_point_id(id, self.documentation_collection_name) def get_training_data(self, **kwargs) -> pd.DataFrame: df = pd.DataFrame() - if sql_data := self._get_all_points(SQL_COLLECTION_NAME): + if sql_data := self._get_all_points(self.sql_collection_name): question_list = [data.payload["question"] for data in sql_data] sql_list = [data.payload["sql"] for data in sql_data] id_list = [ - self._format_point_id(data.id, SQL_COLLECTION_NAME) for data in sql_data + self._format_point_id(data.id, self.sql_collection_name) + for data in sql_data ] df_sql = pd.DataFrame( @@ -154,10 +167,11 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: df = pd.concat([df, df_sql]) - if ddl_data := self._get_all_points(DDL_COLLECTION_NAME): + if ddl_data := self._get_all_points(self.ddl_collection_name): ddl_list = [data.payload["ddl"] for data in ddl_data] id_list = [ - self._format_point_id(data.id, DDL_COLLECTION_NAME) for data in ddl_data + self._format_point_id(data.id, self.ddl_collection_name) + for data in ddl_data ] df_ddl = pd.DataFrame( @@ -172,10 +186,10 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: df = pd.concat([df, df_ddl]) - if doc_data := self._get_all_points(DOCUMENTATION_COLLECTION_NAME): + if doc_data := self._get_all_points(self.documentation_collection_name): document_list = 
[data.payload["documentation"] for data in doc_data] id_list = [ - self._format_point_id(data.id, DOCUMENTATION_COLLECTION_NAME) + self._format_point_id(data.id, self.documentation_collection_name) for data in doc_data ] @@ -210,7 +224,7 @@ def remove_collection(self, collection_name: str) -> bool: Returns: bool: True if collection is deleted, False otherwise """ - if collection_name in ID_SUFFIXES.keys(): + if collection_name in self.id_suffixes.keys(): self._client.delete_collection(collection_name) self._setup_collections() return True @@ -223,7 +237,7 @@ def embeddings_dimension(self): def get_similar_question_sql(self, question: str, **kwargs) -> list: results = self._client.search( - SQL_COLLECTION_NAME, + self.sql_collection_name, query_vector=self.generate_embedding(question), limit=self.n_results, with_payload=True, @@ -233,7 +247,7 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list: def get_related_ddl(self, question: str, **kwargs) -> list: results = self._client.search( - DDL_COLLECTION_NAME, + self.ddl_collection_name, query_vector=self.generate_embedding(question), limit=self.n_results, with_payload=True, @@ -243,7 +257,7 @@ def get_related_ddl(self, question: str, **kwargs) -> list: def get_related_documentation(self, question: str, **kwargs) -> list: results = self._client.search( - DOCUMENTATION_COLLECTION_NAME, + self.documentation_collection_name, query_vector=self.generate_embedding(question), limit=self.n_results, with_payload=True, @@ -282,9 +296,9 @@ def _get_all_points(self, collection_name: str): return results def _setup_collections(self): - if not self._client.collection_exists(SQL_COLLECTION_NAME): + if not self._client.collection_exists(self.sql_collection_name): self._client.create_collection( - collection_name=SQL_COLLECTION_NAME, + collection_name=self.sql_collection_name, vectors_config=models.VectorParams( size=self.embeddings_dimension, distance=self.distance_metric, @@ -292,18 +306,18 @@ def 
_setup_collections(self): **self.collection_params, ) - if not self._client.collection_exists(DDL_COLLECTION_NAME): + if not self._client.collection_exists(self.ddl_collection_name): self._client.create_collection( - collection_name=DDL_COLLECTION_NAME, + collection_name=self.ddl_collection_name, vectors_config=models.VectorParams( size=self.embeddings_dimension, distance=self.distance_metric, ), **self.collection_params, ) - if not self._client.collection_exists(DOCUMENTATION_COLLECTION_NAME): + if not self._client.collection_exists(self.documentation_collection_name): self._client.create_collection( - collection_name=DOCUMENTATION_COLLECTION_NAME, + collection_name=self.documentation_collection_name, vectors_config=models.VectorParams( size=self.embeddings_dimension, distance=self.distance_metric, @@ -312,11 +326,11 @@ def _setup_collections(self): ) def _format_point_id(self, id: str, collection_name: str) -> str: - return "{0}-{1}".format(id, ID_SUFFIXES[collection_name]) + return "{0}-{1}".format(id, self.id_suffixes[collection_name]) def _parse_point_id(self, id: str) -> Tuple[str, str]: id, suffix = id.rsplit("-", 1) - for collection_name, suffix in ID_SUFFIXES.items(): + for collection_name, suffix in self.id_suffixes.items(): if type == suffix: return id, collection_name raise ValueError(f"Invalid id {id}")