Skip to content

Commit 91144d1

Browse files
committed
Fix the pylint error and add the chromadb test
Signed-off-by: SimFG <bang.fu@zilliz.com>
1 parent 94c8471 commit 91144d1

File tree

4 files changed

+190
-142
lines changed

4 files changed

+190
-142
lines changed

gptcache/adapter/adapter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
23
from gptcache import cache
34
from gptcache.processor.post import temperature_softmax
45
from gptcache.utils.error import NotInitError
@@ -16,7 +17,6 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
1617
:param kwargs: llm kwargs
1718
:return: llm result
1819
"""
19-
health_check_flag = kwargs.pop("health_check", False)
2020
search_only_flag = kwargs.pop("search_only", False)
2121
user_temperature = "temperature" in kwargs
2222
user_top_k = "top_k" in kwargs
@@ -114,7 +114,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
114114
continue
115115

116116
# cache consistency check
117-
if health_check_flag:
117+
if chat_cache.config.data_check:
118118
is_healthy = cache_health_check(
119119
chat_cache.data_manager.v,
120120
{
@@ -202,7 +202,7 @@ def post_process():
202202
kwargs["cache_context"] = context
203203
kwargs["cache_skip"] = cache_skip
204204
kwargs["cache_factor"] = cache_factor
205-
kwargs["search_only_flag"] = search_only_flag
205+
kwargs["search_only"] = search_only_flag
206206
llm_data = adapt(
207207
llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
208208
)
@@ -467,8 +467,8 @@ def update_cache_func(handled_llm_data, question=None):
467467
llm_data = update_cache_callback(
468468
llm_data, update_cache_func, *args, **kwargs
469469
)
470-
except Exception as e: # pylint: disable=W0703
471-
gptcache_log.warning("failed to save the data to cache, error: %s", e)
470+
except Exception: # pylint: disable=W0703
471+
gptcache_log.error("failed to save the data to cache", exc_info=True)
472472
return llm_data
473473

474474

gptcache/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def __init__(
4444
enable_token_counter: bool = True,
4545
input_summary_len: Optional[int] = None,
4646
context_len: Optional[int] = None,
47-
skip_list: List[str] = None
47+
skip_list: List[str] = None,
48+
data_check: bool = False,
4849
):
4950
if similarity_threshold < 0 or similarity_threshold > 1:
5051
raise CacheError(
@@ -61,3 +62,4 @@ def __init__(
6162
if skip_list is None:
6263
skip_list = ["system", "assistant"]
6364
self.skip_list = skip_list
65+
self.data_check = data_check
gptcache/manager/vector_data/qdrant.py

Lines changed: 74 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,44 @@
11
from typing import List, Optional
2+
23
import numpy as np
34

5+
from gptcache.manager.vector_data.base import VectorBase, VectorData
46
from gptcache.utils import import_qdrant
57
from gptcache.utils.log import gptcache_log
6-
from gptcache.manager.vector_data.base import VectorBase, VectorData
78

89
import_qdrant()
910

10-
from qdrant_client import QdrantClient # pylint: disable=C0413
11-
from qdrant_client.models import PointStruct, HnswConfigDiff, VectorParams, OptimizersConfigDiff, \
12-
Distance # pylint: disable=C0413
11+
# pylint: disable=C0413
12+
from qdrant_client import QdrantClient
13+
from qdrant_client.models import (
14+
PointStruct,
15+
HnswConfigDiff,
16+
VectorParams,
17+
OptimizersConfigDiff,
18+
Distance,
19+
)
1320

1421

1522
class QdrantVectorStore(VectorBase):
23+
"""Qdrant Vector Store"""
1624

1725
def __init__(
18-
self,
19-
url: Optional[str] = None,
20-
port: Optional[int] = 6333,
21-
grpc_port: int = 6334,
22-
prefer_grpc: bool = False,
23-
https: Optional[bool] = None,
24-
api_key: Optional[str] = None,
25-
prefix: Optional[str] = None,
26-
timeout: Optional[float] = None,
27-
host: Optional[str] = None,
28-
collection_name: Optional[str] = "gptcache",
29-
location: Optional[str] = "./qdrant",
30-
dimension: int = 0,
31-
top_k: int = 1,
32-
flush_interval_sec: int = 5,
33-
index_params: Optional[dict] = None,
26+
self,
27+
url: Optional[str] = None,
28+
port: Optional[int] = 6333,
29+
grpc_port: int = 6334,
30+
prefer_grpc: bool = False,
31+
https: Optional[bool] = None,
32+
api_key: Optional[str] = None,
33+
prefix: Optional[str] = None,
34+
timeout: Optional[float] = None,
35+
host: Optional[str] = None,
36+
collection_name: Optional[str] = "gptcache",
37+
location: Optional[str] = "./qdrant",
38+
dimension: int = 0,
39+
top_k: int = 1,
40+
flush_interval_sec: int = 5,
41+
index_params: Optional[dict] = None,
3442
):
3543
if dimension <= 0:
3644
raise ValueError(
@@ -44,13 +52,17 @@ def __init__(
4452
if self._in_memory or location is not None:
4553
self._create_local(location)
4654
else:
47-
self._create_remote(url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https)
55+
self._create_remote(
56+
url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https
57+
)
4858
self._create_collection(collection_name, flush_interval_sec, index_params)
4959

5060
def _create_local(self, location):
5161
self._client = QdrantClient(location=location)
5262

53-
def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https):
63+
def _create_remote(
64+
self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https
65+
):
5466
self._client = QdrantClient(
5567
url=url,
5668
port=port,
@@ -63,45 +75,70 @@ def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_gr
6375
https=https,
6476
)
6577

66-
def _create_collection(self, collection_name: str, flush_interval_sec: int, index_params: Optional[dict] = None):
78+
def _create_collection(
79+
self,
80+
collection_name: str,
81+
flush_interval_sec: int,
82+
index_params: Optional[dict] = None,
83+
):
6784
hnsw_config = HnswConfigDiff(**(index_params or {}))
68-
vectors_config = VectorParams(size=self.dimension, distance=Distance.COSINE,
69-
hnsw_config=hnsw_config)
70-
optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000,
71-
flush_interval_sec=flush_interval_sec)
85+
vectors_config = VectorParams(
86+
size=self.dimension, distance=Distance.COSINE, hnsw_config=hnsw_config
87+
)
88+
optimizers_config = OptimizersConfigDiff(
89+
deleted_threshold=0.2,
90+
vacuum_min_vector_number=1000,
91+
flush_interval_sec=flush_interval_sec,
92+
)
7293
# check if the collection exists
7394
existing_collections = self._client.get_collections()
7495
for existing_collection in existing_collections.collections:
7596
if existing_collection.name == collection_name:
76-
gptcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name)
97+
gptcache_log.warning(
98+
"The %s collection already exists, and it will be used directly.",
99+
collection_name,
100+
)
77101
break
78102
else:
79-
self._client.create_collection(collection_name=collection_name, vectors_config=vectors_config,
80-
optimizers_config=optimizers_config)
103+
self._client.create_collection(
104+
collection_name=collection_name,
105+
vectors_config=vectors_config,
106+
optimizers_config=optimizers_config,
107+
)
81108

82109
def mul_add(self, datas: List[VectorData]):
83-
points = [PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas]
84-
self._client.upsert(collection_name=self._collection_name, points=points, wait=False)
110+
points = [
111+
PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas
112+
]
113+
self._client.upsert(
114+
collection_name=self._collection_name, points=points, wait=False
115+
)
85116

86117
def search(self, data: np.ndarray, top_k: int = -1):
87118
if top_k == -1:
88119
top_k = self.top_k
89120
reshaped_data = data.reshape(-1).tolist()
90-
search_result = self._client.search(collection_name=self._collection_name, query_vector=reshaped_data,
91-
limit=top_k)
121+
search_result = self._client.search(
122+
collection_name=self._collection_name,
123+
query_vector=reshaped_data,
124+
limit=top_k,
125+
)
92126
return list(map(lambda x: (x.score, x.id), search_result))
93127

94128
def delete(self, ids: List[str]):
95129
self._client.delete(collection_name=self._collection_name, points_selector=ids)
96130

97131
def rebuild(self, ids=None): # pylint: disable=unused-argument
98-
optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000)
99-
self._client.update_collection(collection_name=self._collection_name, optimizer_config=optimizers_config)
132+
optimizers_config = OptimizersConfigDiff(
133+
deleted_threshold=0.2, vacuum_min_vector_number=1000
134+
)
135+
self._client.update_collection(
136+
collection_name=self._collection_name, optimizer_config=optimizers_config
137+
)
100138

101139
def flush(self):
102140
# no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant
103141
pass
104142

105-
106143
def close(self):
107144
self.flush()

Comments (0)