Skip to content

Commit 8cd3fa5

Browse files
committed
Fix the pylint error and add the chromedb test
Signed-off-by: SimFG <bang.fu@zilliz.com>
1 parent 94c8471 commit 8cd3fa5

File tree

5 files changed

+234
-141
lines changed

5 files changed

+234
-141
lines changed

examples/bug/discord.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import time
2+
3+
from langchain.chat_models import ChatOpenAI
4+
from langchain.schema import HumanMessage
5+
6+
from gptcache import Cache
7+
from gptcache.adapter.langchain_models import LangChainChat
8+
from gptcache.embedding import Onnx
9+
from gptcache.manager import CacheBase, manager_factory
10+
from gptcache.processor.pre import get_messages_last_content
11+
from gptcache.similarity_evaluation import SearchDistanceEvaluation
12+
13+
az_gpt = ChatOpenAI()
14+
15+
llm_cache = Cache()
16+
17+
onnx = Onnx()
18+
19+
cache_base = CacheBase('sqlite')
20+
data_manager = manager_factory("sqlite,faiss", data_dir="sqlite_faiss", scalar_params={}, vector_params={"dimension": onnx.dimension})
21+
llm_cache.init(
22+
pre_embedding_func=get_messages_last_content,
23+
data_manager=data_manager,
24+
embedding_func=onnx.to_embeddings,
25+
similarity_evaluation=SearchDistanceEvaluation()
26+
)
27+
28+
cached_chat = LangChainChat(chat=az_gpt)
29+
30+
conversation_history = []
31+
32+
while True or len(conversation_history) < 5:
33+
# Get user input
34+
user_input = input("Human: ")
35+
36+
conversation_history.append(user_input)
37+
human_message_prompt = [HumanMessage(content=user_input)]
38+
start_time = time.time()
39+
print(human_message_prompt)
40+
print(llm_cache)
41+
response = cached_chat(messages=human_message_prompt, cache_obj=llm_cache)
42+
end_time = time.time()
43+
44+
# Calculate the time taken
45+
time_taken = end_time - start_time
46+
print("Time taken:", time_taken, "seconds")
47+
48+
# Print the response
49+
print("AI:", response)

gptcache/adapter/adapter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
23
from gptcache import cache
34
from gptcache.processor.post import temperature_softmax
45
from gptcache.utils.error import NotInitError
@@ -16,7 +17,6 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
1617
:param kwargs: llm kwargs
1718
:return: llm result
1819
"""
19-
health_check_flag = kwargs.pop("health_check", False)
2020
search_only_flag = kwargs.pop("search_only", False)
2121
user_temperature = "temperature" in kwargs
2222
user_top_k = "top_k" in kwargs
@@ -114,7 +114,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
114114
continue
115115

116116
# cache consistency check
117-
if health_check_flag:
117+
if chat_cache.config.data_check:
118118
is_healthy = cache_health_check(
119119
chat_cache.data_manager.v,
120120
{
@@ -202,7 +202,7 @@ def post_process():
202202
kwargs["cache_context"] = context
203203
kwargs["cache_skip"] = cache_skip
204204
kwargs["cache_factor"] = cache_factor
205-
kwargs["search_only_flag"] = search_only_flag
205+
kwargs["search_only"] = search_only_flag
206206
llm_data = adapt(
207207
llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
208208
)
@@ -467,8 +467,8 @@ def update_cache_func(handled_llm_data, question=None):
467467
llm_data = update_cache_callback(
468468
llm_data, update_cache_func, *args, **kwargs
469469
)
470-
except Exception as e: # pylint: disable=W0703
471-
gptcache_log.warning("failed to save the data to cache, error: %s", e)
470+
except Exception: # pylint: disable=W0703
471+
gptcache_log.error("failed to save the data to cache", exc_info=True)
472472
return llm_data
473473

474474

gptcache/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def __init__(
4444
enable_token_counter: bool = True,
4545
input_summary_len: Optional[int] = None,
4646
context_len: Optional[int] = None,
47-
skip_list: List[str] = None
47+
skip_list: List[str] = None,
48+
data_check: bool = False,
4849
):
4950
if similarity_threshold < 0 or similarity_threshold > 1:
5051
raise CacheError(
@@ -61,3 +62,4 @@ def __init__(
6162
if skip_list is None:
6263
skip_list = ["system", "assistant"]
6364
self.skip_list = skip_list
65+
self.data_check = data_check

gptcache/manager/vector_data/qdrant.py

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,33 @@
88
import_qdrant()
99

1010
from qdrant_client import QdrantClient # pylint: disable=C0413
11-
from qdrant_client.models import PointStruct, HnswConfigDiff, VectorParams, OptimizersConfigDiff, \
12-
Distance # pylint: disable=C0413
11+
from qdrant_client.models import (
12+
PointStruct,
13+
HnswConfigDiff,
14+
VectorParams,
15+
OptimizersConfigDiff,
16+
Distance,
17+
) # pylint: disable=C0413
1318

1419

1520
class QdrantVectorStore(VectorBase):
16-
1721
def __init__(
18-
self,
19-
url: Optional[str] = None,
20-
port: Optional[int] = 6333,
21-
grpc_port: int = 6334,
22-
prefer_grpc: bool = False,
23-
https: Optional[bool] = None,
24-
api_key: Optional[str] = None,
25-
prefix: Optional[str] = None,
26-
timeout: Optional[float] = None,
27-
host: Optional[str] = None,
28-
collection_name: Optional[str] = "gptcache",
29-
location: Optional[str] = "./qdrant",
30-
dimension: int = 0,
31-
top_k: int = 1,
32-
flush_interval_sec: int = 5,
33-
index_params: Optional[dict] = None,
22+
self,
23+
url: Optional[str] = None,
24+
port: Optional[int] = 6333,
25+
grpc_port: int = 6334,
26+
prefer_grpc: bool = False,
27+
https: Optional[bool] = None,
28+
api_key: Optional[str] = None,
29+
prefix: Optional[str] = None,
30+
timeout: Optional[float] = None,
31+
host: Optional[str] = None,
32+
collection_name: Optional[str] = "gptcache",
33+
location: Optional[str] = "./qdrant",
34+
dimension: int = 0,
35+
top_k: int = 1,
36+
flush_interval_sec: int = 5,
37+
index_params: Optional[dict] = None,
3438
):
3539
if dimension <= 0:
3640
raise ValueError(
@@ -44,13 +48,17 @@ def __init__(
4448
if self._in_memory or location is not None:
4549
self._create_local(location)
4650
else:
47-
self._create_remote(url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https)
51+
self._create_remote(
52+
url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https
53+
)
4854
self._create_collection(collection_name, flush_interval_sec, index_params)
4955

5056
def _create_local(self, location):
5157
self._client = QdrantClient(location=location)
5258

53-
def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https):
59+
def _create_remote(
60+
self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https
61+
):
5462
self._client = QdrantClient(
5563
url=url,
5664
port=port,
@@ -63,45 +71,70 @@ def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_gr
6371
https=https,
6472
)
6573

66-
def _create_collection(self, collection_name: str, flush_interval_sec: int, index_params: Optional[dict] = None):
74+
def _create_collection(
75+
self,
76+
collection_name: str,
77+
flush_interval_sec: int,
78+
index_params: Optional[dict] = None,
79+
):
6780
hnsw_config = HnswConfigDiff(**(index_params or {}))
68-
vectors_config = VectorParams(size=self.dimension, distance=Distance.COSINE,
69-
hnsw_config=hnsw_config)
70-
optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000,
71-
flush_interval_sec=flush_interval_sec)
81+
vectors_config = VectorParams(
82+
size=self.dimension, distance=Distance.COSINE, hnsw_config=hnsw_config
83+
)
84+
optimizers_config = OptimizersConfigDiff(
85+
deleted_threshold=0.2,
86+
vacuum_min_vector_number=1000,
87+
flush_interval_sec=flush_interval_sec,
88+
)
7289
# check if the collection exists
7390
existing_collections = self._client.get_collections()
7491
for existing_collection in existing_collections.collections:
7592
if existing_collection.name == collection_name:
76-
gptcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name)
93+
gptcache_log.warning(
94+
"The %s collection already exists, and it will be used directly.",
95+
collection_name,
96+
)
7797
break
7898
else:
79-
self._client.create_collection(collection_name=collection_name, vectors_config=vectors_config,
80-
optimizers_config=optimizers_config)
99+
self._client.create_collection(
100+
collection_name=collection_name,
101+
vectors_config=vectors_config,
102+
optimizers_config=optimizers_config,
103+
)
81104

82105
def mul_add(self, datas: List[VectorData]):
83-
points = [PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas]
84-
self._client.upsert(collection_name=self._collection_name, points=points, wait=False)
106+
points = [
107+
PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas
108+
]
109+
self._client.upsert(
110+
collection_name=self._collection_name, points=points, wait=False
111+
)
85112

86113
def search(self, data: np.ndarray, top_k: int = -1):
87114
if top_k == -1:
88115
top_k = self.top_k
89116
reshaped_data = data.reshape(-1).tolist()
90-
search_result = self._client.search(collection_name=self._collection_name, query_vector=reshaped_data,
91-
limit=top_k)
117+
search_result = self._client.search(
118+
collection_name=self._collection_name,
119+
query_vector=reshaped_data,
120+
limit=top_k,
121+
)
92122
return list(map(lambda x: (x.score, x.id), search_result))
93123

94124
def delete(self, ids: List[str]):
95125
self._client.delete(collection_name=self._collection_name, points_selector=ids)
96126

97127
def rebuild(self, ids=None): # pylint: disable=unused-argument
98-
optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000)
99-
self._client.update_collection(collection_name=self._collection_name, optimizer_config=optimizers_config)
128+
optimizers_config = OptimizersConfigDiff(
129+
deleted_threshold=0.2, vacuum_min_vector_number=1000
130+
)
131+
self._client.update_collection(
132+
collection_name=self._collection_name, optimizer_config=optimizers_config
133+
)
100134

101135
def flush(self):
102136
# no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant
103137
pass
104138

105-
106139
def close(self):
107140
self.flush()

0 commit comments

Comments
 (0)