2 changes: 1 addition & 1 deletion gptcache/manager/scalar_data/mongo.py
@@ -324,4 +324,4 @@ def report_cache(

def close(self):
me.disconnect()
self.con.close()
26 changes: 26 additions & 0 deletions gptcache/manager/vector_data/manager.py
@@ -257,6 +257,32 @@ def get(name, **kwargs):
flush_interval_sec=flush_interval_sec,
index_params=index_params,
)
elif name == "weaviate":
from gptcache.manager.vector_data.weaviate import Weaviate
url = kwargs.get("url", None)
            auth_client_secret = kwargs.get("auth_client_secret", None)
timeout_config = kwargs.get("timeout_config", (10, 60))
proxies = kwargs.get("proxies", None)
trust_env = kwargs.get("trust_env", False)
additional_headers = kwargs.get("additional_headers", None)
startup_period = kwargs.get("startup_period", 5)
embedded_options = kwargs.get("embedded_options", None)
additional_config = kwargs.get("additional_config", None)
class_name = kwargs.get("class_name", "Gptcache")
top_k = kwargs.get("top_k", 1)
            vector_base = Weaviate(
                url=url,
                auth_client_secret=auth_client_secret,
                timeout_config=timeout_config,
                proxies=proxies,
                trust_env=trust_env,
                additional_headers=additional_headers,
                startup_period=startup_period,
                embedded_options=embedded_options,
                additional_config=additional_config,
                class_name=class_name,
                top_k=top_k,
            )
else:
raise NotFoundError("vector store", name)
return vector_base
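A minimal usage sketch of the new factory branch (the class_name and top_k values here are illustrative; with no url the branch falls back to Weaviate's embedded mode):

from gptcache.manager.vector_data import VectorBase

# No url is passed, so the factory builds a Client with EmbeddedOptions().
weaviate_store = VectorBase("weaviate", class_name="Gptcache", top_k=3)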
131 changes: 131 additions & 0 deletions gptcache/manager/vector_data/weaviate.py
@@ -0,0 +1,131 @@
from typing import List, Optional, Union

import numpy as np

from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_weaviate
from gptcache.utils.log import gptcache_log

import_weaviate()

from weaviate import Client, EmbeddedOptions, Config


class Weaviate(VectorBase):
"""Weaviate Vector store"""
    def __init__(
        self,
        url: Optional[str] = None,
        auth_client_secret=None,
        timeout_config=(10, 60),
        proxies: Optional[Union[dict, str]] = None,
        trust_env: bool = False,
        additional_headers: Optional[dict] = None,
        startup_period: Optional[int] = 5,
        embedded_options=None,
        additional_config=None,
        top_k: int = 1,
        distance: str = "cosine",
        class_name: str = "Gptcache",
    ):
        self.class_name = class_name
        self.top_k = top_k
        self.distance = distance
        if not url:
            # No url given: start an embedded Weaviate instance.
            self.client = Client(
                embedded_options=EmbeddedOptions(),
                startup_period=startup_period,
                timeout_config=timeout_config,
                additional_config=additional_config,
            )
        else:
            self.client = Client(
                url=url,
                auth_client_secret=auth_client_secret,
                timeout_config=timeout_config,
                proxies=proxies,
                trust_env=trust_env,
                additional_headers=additional_headers,
                startup_period=startup_period,
                embedded_options=embedded_options,
                additional_config=additional_config,
            )
        # Make sure the backing class exists before any add or search call.
        self._create_collection(self.class_name)

def _create_collection(self, class_name: str):
if not class_name:
class_name = self.class_name
if self.client.schema.exists(class_name):
gptcache_log.info(
"The %s already exists, and it will be used directly", class_name
)
else:
gptcache_class_schema = {
"class": class_name,
"description": "caching LLM responses",
"properties": [
{
"name": "id_",
"dataType": ["int"],
}
],
                "vectorIndexConfig": {
                    "distance": self.distance
                },
            }
self.client.schema.create_class(gptcache_class_schema)

    def mul_add(self, datas: List[VectorData]):
        # Import all vectors in a single batch.
        with self.client.batch(batch_size=len(datas)) as batch:
            for data in datas:
                properties = {
                    "id_": data.id,
                }
                batch.add_data_object(
                    data_object=properties,
                    class_name=self.class_name,
                    vector=data.data.tolist(),
                )

    def search(self, data: np.ndarray, top_k: int = -1):
        if not self.client.schema.exists(self.class_name):
            self._create_collection(self.class_name)
        if top_k == -1:
            top_k = self.top_k
        result = (
            self.client.query.get(class_name=self.class_name, properties=["id_"])
            .with_near_vector(content={"vector": data.tolist()})
            .with_additional(["distance"])
            .with_limit(top_k)
            .do()
        )
        return list(
            map(
                lambda x: (x["_additional"]["distance"], x["id_"]),
                result["data"]["Get"][self.class_name],
            )
        )

    def get_uuids(self, ids: List[int]):
        uuid_list = []
        for id_ in ids:
            res = (
                self.client.query.get(class_name=self.class_name, properties=["id_"])
                .with_where({"path": ["id_"], "operator": "Equal", "valueInt": id_})
                .with_additional(["id"])
                .do()
            )
            uuid_list.append(res["data"]["Get"][self.class_name][0]["_additional"]["id"])
        return uuid_list

    def delete(self, ids: List[int]):
        uuids = self.get_uuids(ids)
        for uuid_ in uuids:
            self.client.data_object.delete(class_name=self.class_name, uuid=uuid_)

    def rebuild(self, ids=None):
        return

def flush(self):
return True

def close(self):
pass
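
A minimal end-to-end sketch of the class above, assuming an embedded Weaviate instance can start locally; the vector dimension and ids are illustrative:

import numpy as np

from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.weaviate import Weaviate

store = Weaviate(top_k=2)  # no url, so an embedded instance is started
store.mul_add([VectorData(id=i, data=np.random.randn(8).astype(np.float32)) for i in range(4)])
print(store.search(np.random.randn(8).astype(np.float32)))  # [(distance, id_), (distance, id_)]
store.delete([0, 1])
store.close()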

8 changes: 7 additions & 1 deletion gptcache/utils/__init__.py
@@ -41,6 +41,7 @@
"import_fastapi",
"import_redis",
"import_qdrant",
"import_weaviate"
]

import importlib.util
@@ -116,7 +117,7 @@ def import_hnswlib():


def import_chromadb():
_check_library("chromadb")
_check_library("chromadb", package="chromadb==0.3.26")


def import_sqlalchemy():
@@ -260,5 +261,10 @@ def import_redis():
_check_library("redis_om")


def import_weaviate():
    _check_library("weaviate", package="weaviate-client")


def import_starlette():
_check_library("starlette")

2 changes: 2 additions & 0 deletions tests/requirements.txt
@@ -26,3 +26,5 @@ grpcio==1.53.0
protobuf==3.20.0
milvus==2.2.8
pymilvus==2.2.8
pymongo
mongoengine
30 changes: 30 additions & 0 deletions tests/unit_tests/manager/test_weaviate.py
@@ -0,0 +1,30 @@
import unittest

import numpy as np

from gptcache.manager.vector_data import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestWeaviateDB(unittest.TestCase):
def test_normal(self):
size = 1000
dim = 512
top_k = 10
        weaviate = VectorBase("weaviate", top_k=top_k)
data = np.random.randn(size, dim).astype(np.float32)
weaviate.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))])
search_result = weaviate.search(data[0], top_k)
self.assertEqual(len(search_result), top_k)
weaviate.mul_add([VectorData(id=size, data=data[0])])
ret = weaviate.search(data[0])
self.assertIn(ret[0][1], [0, size])
self.assertIn(ret[1][1], [0, size])
weaviate.delete([0, 1, 2, 3, 4, 5, size])
ret = weaviate.search(data[0])
self.assertNotIn(ret[0][1], [0, size])
weaviate.rebuild()
weaviate.close()
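
To exercise this test locally, something like the following should work (assuming the weaviate-client package is already installed in the environment, e.g. via pip install weaviate-client):

python -m pytest tests/unit_tests/manager/test_weaviate.py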