Skip to content

Commit 94c8471

Browse files
parthvnpSimFG
authored andcommitted
Add support for Qdrant Vector Store (#453)
* Add Qdrant vector store client * Add setup for Qdrant vector store * Add lazy import for Qdrant client library * Add import_qdrant to index file * Use models not types for building configs * Add tests for Qdrant vector store
1 parent 8029262 commit 94c8471

File tree

4 files changed

+186
-1
lines changed

4 files changed

+186
-1
lines changed

gptcache/manager/vector_data/manager.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
PGVECTOR_URL = "postgresql://postgres:postgres@localhost:5432/postgres"
2020
PGVECTOR_INDEX_PARAMS = {"index_type": "L2", "params": {"lists": 100, "probes": 10}}
2121

22+
QDRANT_GRPC_PORT = 6334
23+
QDRANT_HTTP_PORT = 6333
24+
QDRANT_INDEX_PARAMS = {"ef_construct": 100, "m": 16}
25+
QDRANT_DEFAULT_LOCATION = "./qdrant_data"
26+
QDRANT_FLUSH_INTERVAL_SEC = 5
27+
2228
COLLECTION_NAME = "gptcache"
2329

2430

@@ -217,6 +223,40 @@ def get(name, **kwargs):
217223
collection_name=collection_name,
218224
top_k=top_k,
219225
)
226+
elif name == "qdrant":
227+
from gptcache.manager.vector_data.qdrant import QdrantVectorStore
228+
url = kwargs.get("url", None)
229+
port = kwargs.get("port", QDRANT_HTTP_PORT)
230+
grpc_port = kwargs.get("grpc_port", QDRANT_GRPC_PORT)
231+
prefer_grpc = kwargs.get("prefer_grpc", False)
232+
https = kwargs.get("https", False)
233+
api_key = kwargs.get("api_key", None)
234+
prefix = kwargs.get("prefix", None)
235+
timeout = kwargs.get("timeout", None)
236+
host = kwargs.get("host", None)
237+
collection_name = kwargs.get("collection_name", COLLECTION_NAME)
238+
location = kwargs.get("location", QDRANT_DEFAULT_LOCATION)
239+
dimension = kwargs.get("dimension", DIMENSION)
240+
top_k: int = kwargs.get("top_k", TOP_K)
241+
flush_interval_sec = kwargs.get("flush_interval_sec", QDRANT_FLUSH_INTERVAL_SEC)
242+
index_params = kwargs.get("index_params", QDRANT_INDEX_PARAMS)
243+
vector_base = QdrantVectorStore(
244+
url=url,
245+
port=port,
246+
grpc_port=grpc_port,
247+
prefer_grpc=prefer_grpc,
248+
https=https,
249+
api_key=api_key,
250+
prefix=prefix,
251+
timeout=timeout,
252+
host=host,
253+
collection_name=collection_name,
254+
location=location,
255+
dimension=dimension,
256+
top_k=top_k,
257+
flush_interval_sec=flush_interval_sec,
258+
index_params=index_params,
259+
)
220260
else:
221261
raise NotFoundError("vector store", name)
222262
return vector_base
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from typing import List, Optional
2+
import numpy as np
3+
4+
from gptcache.utils import import_qdrant
5+
from gptcache.utils.log import gptcache_log
6+
from gptcache.manager.vector_data.base import VectorBase, VectorData
7+
8+
import_qdrant()
9+
10+
from qdrant_client import QdrantClient # pylint: disable=C0413
11+
from qdrant_client.models import PointStruct, HnswConfigDiff, VectorParams, OptimizersConfigDiff, \
12+
Distance # pylint: disable=C0413
13+
14+
15+
class QdrantVectorStore(VectorBase):
16+
17+
def __init__(
18+
self,
19+
url: Optional[str] = None,
20+
port: Optional[int] = 6333,
21+
grpc_port: int = 6334,
22+
prefer_grpc: bool = False,
23+
https: Optional[bool] = None,
24+
api_key: Optional[str] = None,
25+
prefix: Optional[str] = None,
26+
timeout: Optional[float] = None,
27+
host: Optional[str] = None,
28+
collection_name: Optional[str] = "gptcache",
29+
location: Optional[str] = "./qdrant",
30+
dimension: int = 0,
31+
top_k: int = 1,
32+
flush_interval_sec: int = 5,
33+
index_params: Optional[dict] = None,
34+
):
35+
if dimension <= 0:
36+
raise ValueError(
37+
f"invalid `dim` param: {dimension} in the Qdrant vector store."
38+
)
39+
self._client: QdrantClient
40+
self._collection_name = collection_name
41+
self._in_memory = location == ":memory:"
42+
self.dimension = dimension
43+
self.top_k = top_k
44+
if self._in_memory or location is not None:
45+
self._create_local(location)
46+
else:
47+
self._create_remote(url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https)
48+
self._create_collection(collection_name, flush_interval_sec, index_params)
49+
50+
def _create_local(self, location):
51+
self._client = QdrantClient(location=location)
52+
53+
def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https):
54+
self._client = QdrantClient(
55+
url=url,
56+
port=port,
57+
api_key=api_key,
58+
timeout=timeout,
59+
host=host,
60+
grpc_port=grpc_port,
61+
prefer_grpc=prefer_grpc,
62+
prefix=prefix,
63+
https=https,
64+
)
65+
66+
def _create_collection(self, collection_name: str, flush_interval_sec: int, index_params: Optional[dict] = None):
67+
hnsw_config = HnswConfigDiff(**(index_params or {}))
68+
vectors_config = VectorParams(size=self.dimension, distance=Distance.COSINE,
69+
hnsw_config=hnsw_config)
70+
optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000,
71+
flush_interval_sec=flush_interval_sec)
72+
# check if the collection exists
73+
existing_collections = self._client.get_collections()
74+
for existing_collection in existing_collections.collections:
75+
if existing_collection.name == collection_name:
76+
gptcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name)
77+
break
78+
else:
79+
self._client.create_collection(collection_name=collection_name, vectors_config=vectors_config,
80+
optimizers_config=optimizers_config)
81+
82+
def mul_add(self, datas: List[VectorData]):
83+
points = [PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas]
84+
self._client.upsert(collection_name=self._collection_name, points=points, wait=False)
85+
86+
def search(self, data: np.ndarray, top_k: int = -1):
87+
if top_k == -1:
88+
top_k = self.top_k
89+
reshaped_data = data.reshape(-1).tolist()
90+
search_result = self._client.search(collection_name=self._collection_name, query_vector=reshaped_data,
91+
limit=top_k)
92+
return list(map(lambda x: (x.score, x.id), search_result))
93+
94+
def delete(self, ids: List[str]):
95+
self._client.delete(collection_name=self._collection_name, points_selector=ids)
96+
97+
def rebuild(self, ids=None): # pylint: disable=unused-argument
98+
optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000)
99+
self._client.update_collection(collection_name=self._collection_name, optimizer_config=optimizers_config)
100+
101+
def flush(self):
102+
# no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant
103+
pass
104+
105+
106+
def close(self):
107+
self.flush()

gptcache/utils/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
"import_paddlenlp",
3939
"import_tiktoken",
4040
"import_fastapi",
41-
"import_redis"
41+
"import_redis",
42+
"import_qdrant"
4243
]
4344

4445
import importlib.util
@@ -65,6 +66,10 @@ def import_milvus_lite():
6566
_check_library("milvus")
6667

6768

69+
def import_qdrant():
70+
_check_library("qdrant_client")
71+
72+
6873
def import_sbert():
6974
_check_library("sentence_transformers", package="sentence-transformers")
7075

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
import unittest
3+
4+
import numpy as np
5+
6+
from gptcache.manager.vector_data import VectorBase
7+
from gptcache.manager.vector_data.base import VectorData
8+
9+
10+
class TestQdrant(unittest.TestCase):
11+
def test_normal(self):
12+
size = 10
13+
dim = 2
14+
top_k = 10
15+
qdrant = VectorBase(
16+
"qdrant",
17+
top_k=top_k,
18+
dimension=dim,
19+
location=":memory:"
20+
)
21+
data = np.random.randn(size, dim).astype(np.float32)
22+
qdrant.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))])
23+
search_result = qdrant.search(data[0], top_k)
24+
self.assertEqual(len(search_result), top_k)
25+
qdrant.mul_add([VectorData(id=size, data=data[0])])
26+
ret = qdrant.search(data[0])
27+
self.assertIn(ret[0][1], [0, size])
28+
self.assertIn(ret[1][1], [0, size])
29+
qdrant.delete([0, 1, 2, 3, 4, 5, size])
30+
ret = qdrant.search(data[0])
31+
self.assertNotIn(ret[0][1], [0, size])
32+
qdrant.rebuild()
33+
qdrant.close()

0 commit comments

Comments
 (0)