Skip to content

Commit

Permalink
Merge pull request vanna-ai#407 from Anush008/config-collection-name
Browse files Browse the repository at this point in the history
refactor: Configurable collection names in Qdrant_VectorStore
  • Loading branch information
zainhoda authored May 6, 2024
2 parents c23620a + 8b022f3 commit e7f1a12
Showing 1 changed file with 71 additions and 57 deletions.
128 changes: 71 additions & 57 deletions src/vanna/qdrant/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,51 @@
from ..base import VannaBase
from ..utils import deterministic_uuid

DOCUMENTATION_COLLECTION_NAME = "documentation"
DDL_COLLECTION_NAME = "ddl"
SQL_COLLECTION_NAME = "sql"
SCROLL_SIZE = 1000

ID_SUFFIXES = {
DDL_COLLECTION_NAME: "ddl",
DOCUMENTATION_COLLECTION_NAME: "doc",
SQL_COLLECTION_NAME: "sql",
}


class Qdrant_VectorStore(VannaBase):
"""Vectorstore implementation using Qdrant - https://qdrant.tech/"""
"""
Vectorstore implementation using Qdrant - https://qdrant.tech/
Args:
- config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
- client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
- location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
- url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
- prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
- https: If `true` - use HTTPS(SSL) protocol. Default: `None`
- api_key: API key for authentication in Qdrant Cloud. Default: `None`
- timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
- path: Persistence path for QdrantLocal. Default: `None`.
- prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
- n_results: Number of results to return from similarity search. Defaults to 10.
- fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
Defaults to `"BAAI/bge-small-en-v1.5"`.
- collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
- distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
- documentation_collection_name: Name of the collection to store documentation. Defaults to `"documentation"`.
- ddl_collection_name: Name of the collection to store DDL. Defaults to `"ddl"`.
- sql_collection_name: Name of the collection to store SQL. Defaults to `"sql"`.
Raises:
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
"""

documentation_collection_name = "documentation"
ddl_collection_name = "ddl"
sql_collection_name = "sql"

id_suffixes = {
ddl_collection_name: "ddl",
documentation_collection_name: "doc",
sql_collection_name: "sql",
}

def __init__(
self,
config={},
):
"""
Vectorstore implementation using Qdrant - https://qdrant.tech/
Args:
- config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
- client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
- location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
- url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
- prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
- https: If `true` - use HTTPS(SSL) protocol. Default: `None`
- api_key: API key for authentication in Qdrant Cloud. Default: `None`
- timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
- path: Persistence path for QdrantLocal. Default: `None`.
- prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
- n_results: Number of results to return from similarity search. Defaults to 10.
- fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
Defaults to `"BAAI/bge-small-en-v1.5"`.
- collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
- distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
Raises:
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
"""
VannaBase.__init__(self, config=config)
client = config.get("client")

Expand All @@ -75,6 +78,15 @@ def __init__(
self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
self.collection_params = config.get("collection_params", {})
self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
self.documentation_collection_name = config.get(
"documentation_collection_name", self.documentation_collection_name
)
self.ddl_collection_name = config.get(
"ddl_collection_name", self.ddl_collection_name
)
self.sql_collection_name = config.get(
"sql_collection_name", self.sql_collection_name
)

self._setup_collections()

Expand All @@ -83,7 +95,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
id = deterministic_uuid(question_answer)

self._client.upsert(
SQL_COLLECTION_NAME,
self.sql_collection_name,
points=[
models.PointStruct(
id=id,
Expand All @@ -96,12 +108,12 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
],
)

return self._format_point_id(id, SQL_COLLECTION_NAME)
return self._format_point_id(id, self.sql_collection_name)

def add_ddl(self, ddl: str, **kwargs) -> str:
id = deterministic_uuid(ddl)
self._client.upsert(
DDL_COLLECTION_NAME,
self.ddl_collection_name,
points=[
models.PointStruct(
id=id,
Expand All @@ -112,13 +124,13 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
)
],
)
return self._format_point_id(id, DDL_COLLECTION_NAME)
return self._format_point_id(id, self.ddl_collection_name)

def add_documentation(self, documentation: str, **kwargs) -> str:
id = deterministic_uuid(documentation)

self._client.upsert(
DOCUMENTATION_COLLECTION_NAME,
self.documentation_collection_name,
points=[
models.PointStruct(
id=id,
Expand All @@ -130,16 +142,17 @@ def add_documentation(self, documentation: str, **kwargs) -> str:
],
)

return self._format_point_id(id, DOCUMENTATION_COLLECTION_NAME)
return self._format_point_id(id, self.documentation_collection_name)

def get_training_data(self, **kwargs) -> pd.DataFrame:
df = pd.DataFrame()

if sql_data := self._get_all_points(SQL_COLLECTION_NAME):
if sql_data := self._get_all_points(self.sql_collection_name):
question_list = [data.payload["question"] for data in sql_data]
sql_list = [data.payload["sql"] for data in sql_data]
id_list = [
self._format_point_id(data.id, SQL_COLLECTION_NAME) for data in sql_data
self._format_point_id(data.id, self.sql_collection_name)
for data in sql_data
]

df_sql = pd.DataFrame(
Expand All @@ -154,10 +167,11 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

df = pd.concat([df, df_sql])

if ddl_data := self._get_all_points(DDL_COLLECTION_NAME):
if ddl_data := self._get_all_points(self.ddl_collection_name):
ddl_list = [data.payload["ddl"] for data in ddl_data]
id_list = [
self._format_point_id(data.id, DDL_COLLECTION_NAME) for data in ddl_data
self._format_point_id(data.id, self.ddl_collection_name)
for data in ddl_data
]

df_ddl = pd.DataFrame(
Expand All @@ -172,10 +186,10 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

df = pd.concat([df, df_ddl])

if doc_data := self._get_all_points(DOCUMENTATION_COLLECTION_NAME):
if doc_data := self._get_all_points(self.documentation_collection_name):
document_list = [data.payload["documentation"] for data in doc_data]
id_list = [
self._format_point_id(data.id, DOCUMENTATION_COLLECTION_NAME)
self._format_point_id(data.id, self.documentation_collection_name)
for data in doc_data
]

Expand Down Expand Up @@ -210,7 +224,7 @@ def remove_collection(self, collection_name: str) -> bool:
Returns:
bool: True if collection is deleted, False otherwise
"""
if collection_name in ID_SUFFIXES.keys():
if collection_name in self.id_suffixes.keys():
self._client.delete_collection(collection_name)
self._setup_collections()
return True
Expand All @@ -223,7 +237,7 @@ def embeddings_dimension(self):

def get_similar_question_sql(self, question: str, **kwargs) -> list:
results = self._client.search(
SQL_COLLECTION_NAME,
self.sql_collection_name,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
Expand All @@ -233,7 +247,7 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list:

def get_related_ddl(self, question: str, **kwargs) -> list:
results = self._client.search(
DDL_COLLECTION_NAME,
self.ddl_collection_name,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
Expand All @@ -243,7 +257,7 @@ def get_related_ddl(self, question: str, **kwargs) -> list:

def get_related_documentation(self, question: str, **kwargs) -> list:
results = self._client.search(
DOCUMENTATION_COLLECTION_NAME,
self.documentation_collection_name,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
Expand Down Expand Up @@ -282,28 +296,28 @@ def _get_all_points(self, collection_name: str):
return results

def _setup_collections(self):
if not self._client.collection_exists(SQL_COLLECTION_NAME):
if not self._client.collection_exists(self.sql_collection_name):
self._client.create_collection(
collection_name=SQL_COLLECTION_NAME,
collection_name=self.sql_collection_name,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
),
**self.collection_params,
)

if not self._client.collection_exists(DDL_COLLECTION_NAME):
if not self._client.collection_exists(self.ddl_collection_name):
self._client.create_collection(
collection_name=DDL_COLLECTION_NAME,
collection_name=self.ddl_collection_name,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
),
**self.collection_params,
)
if not self._client.collection_exists(DOCUMENTATION_COLLECTION_NAME):
if not self._client.collection_exists(self.documentation_collection_name):
self._client.create_collection(
collection_name=DOCUMENTATION_COLLECTION_NAME,
collection_name=self.documentation_collection_name,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
Expand All @@ -312,11 +326,11 @@ def _setup_collections(self):
)

def _format_point_id(self, id: str, collection_name: str) -> str:
return "{0}-{1}".format(id, ID_SUFFIXES[collection_name])
return "{0}-{1}".format(id, self.id_suffixes[collection_name])

def _parse_point_id(self, id: str) -> Tuple[str, str]:
id, suffix = id.rsplit("-", 1)
for collection_name, suffix in ID_SUFFIXES.items():
for collection_name, suffix in self.id_suffixes.items():
if type == suffix:
return id, collection_name
raise ValueError(f"Invalid id {id}")

0 comments on commit e7f1a12

Please sign in to comment.