-
Notifications
You must be signed in to change notification settings - Fork 44.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/weaviate memory #424
Changes from 9 commits
986d32c
da4ba3c
1e63bc5
0ce0c55
97ac802
76a1462
5fe784a
786ee60
3c7767f
96c5e92
453b428
75c4132
f2a6ac5
e3aea6d
67b84b5
b9a4f97
415c1cb
35ecd95
b7d0cc3
5308946
5592dbd
855de18
067e697
2f8cf68
0c3562f
a94b93b
4c7deef
b987cff
005be02
b2bfd39
2678a5a
8916b76
899c815
5122422
03d2032
23b89b8
4cd412c
37a1dc1
b865e2c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -204,6 +204,22 @@ export PINECONE_ENV="Your pinecone region" # something like: us-east4-gcp | |
|
||
``` | ||
|
||
## Weaviate Setup | ||
|
||
[Weaviate](https://weaviate.io/) is an open-source vector database. It allows to store data objects and vector embeddings from ML-models and scales seamlessly to billion of data objects. [An instance of Weaviate can be created locally (using Docker), on Kubernetes or using Weaviate Cloud Services](https://weaviate.io/developers/weaviate/quickstart). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cs0lar should also mention embedded weaviate here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good spot, thanks! This is now done. |
||
|
||
#### Setting up enviornment variables | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Followed the instructions and ... it did not work 😂 Because it is missing the weaviate client. So, perhaps edit to say something like "First, run There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @hsm207 HA! I had assumed |
||
|
||
In your `.env` file set the following: | ||
|
||
``` | ||
MEMORY_BACKEND=weaviate | ||
WEAVIATE_HOST="http://127.0.0.1" # the URL of the running Weaviate instance | ||
WEAVIATE_PORT="8080" | ||
WEAVIATE_USERNAME="your username" | ||
WEAVIATE_PASSWORD="your password" | ||
MEMORY_INDEX="Autogpt" # name of the index to create for the application | ||
``` | ||
|
||
## View Memory Usage | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,4 @@ pinecone-client==2.2.1 | |
redis | ||
orjson | ||
Pillow | ||
weaviate-client==3.15.4 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And this |
||
import pinecone | ||
|
||
from memory.base import MemoryProviderSingleton, get_ada_embedding | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,100 @@ | ||||||||
from config import Config | ||||||||
from memory.base import MemoryProviderSingleton, get_ada_embedding | ||||||||
import uuid | ||||||||
import weaviate | ||||||||
from weaviate import Client | ||||||||
from weaviate.util import generate_uuid5 | ||||||||
|
||||||||
def default_schema(weaviate_index): | ||||||||
return { | ||||||||
"class": weaviate_index, | ||||||||
"properties": [ | ||||||||
{ | ||||||||
"name": "raw_text", | ||||||||
"dataType": ["text"], | ||||||||
"description": "original text for the embedding" | ||||||||
} | ||||||||
], | ||||||||
} | ||||||||
|
||||||||
class WeaviateMemory(MemoryProviderSingleton): | ||||||||
def __init__(self, cfg): | ||||||||
auth_credentials = self._build_auth_credentials(cfg) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cs0lar I think adding support for API key would be immensely useful for users getting started using weaviate's free sandbox environment, so they can have a similar experience with pinecone i.e. just provide an api key and url: https://weaviate.io/developers/weaviate/client-libraries/python#api-key-authentication There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||||||||
|
||||||||
url = f'{cfg.weaviate_host}:{cfg.weaviate_port}' | ||||||||
|
||||||||
self.client = Client(url, auth_client_secret=auth_credentials) | ||||||||
self.index = cfg.memory_index | ||||||||
self._create_schema() | ||||||||
|
||||||||
def _create_schema(self): | ||||||||
schema = default_schema(self.index) | ||||||||
if not self.client.schema.contains(schema): | ||||||||
self.client.schema.create_class(schema) | ||||||||
|
||||||||
def _build_auth_credentials(self, cfg): | ||||||||
if cfg.weaviate_username and cfg.weaviate_password: | ||||||||
return weaviate_auth.AuthClientPassword(cfg.weaviate_username, cfg.weaviate_password) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cs0lar this line will throw an error as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cs0lar I think you missed this change |
||||||||
else: | ||||||||
return None | ||||||||
|
||||||||
def add(self, data): | ||||||||
vector = get_ada_embedding(data) | ||||||||
|
||||||||
doc_uuid = generate_uuid5(data, self.index) | ||||||||
data_object = { | ||||||||
'class': self.index, | ||||||||
'raw_text': data | ||||||||
} | ||||||||
|
||||||||
with self.client.batch as batch: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cs0lar is data always going to be uploaded one at a time from auto gpt to weaviate? |
||||||||
batch.add_data_object( | ||||||||
uuid=doc_uuid, | ||||||||
data_object=data_object, | ||||||||
class_name=self.index, | ||||||||
vector=vector | ||||||||
) | ||||||||
|
||||||||
batch.flush() | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cs0lar this call is unnecessary since you are using the batch context manager |
||||||||
|
||||||||
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}" | ||||||||
|
||||||||
|
||||||||
def get(self, data): | ||||||||
return self.get_relevant(data, 1) | ||||||||
|
||||||||
|
||||||||
def clear(self): | ||||||||
self.client.schema.delete_all() | ||||||||
|
||||||||
# weaviate does not yet have a neat way to just remove the items in an index | ||||||||
# without removing the entire schema, therefore we need to re-create it | ||||||||
# after a call to delete_all | ||||||||
self._create_schema() | ||||||||
|
||||||||
return 'Obliterated' | ||||||||
|
||||||||
def get_relevant(self, data, num_relevant=5): | ||||||||
query_embedding = get_ada_embedding(data) | ||||||||
try: | ||||||||
results = self.client.query.get(self.index, ['raw_text']) \ | ||||||||
.with_near_vector({'vector': query_embedding, 'certainty': 0.7}) \ | ||||||||
.with_limit(num_relevant) \ | ||||||||
.do() | ||||||||
|
||||||||
if len(results['data']['Get'][self.index]) > 0: | ||||||||
return [str(item['raw_text']) for item in results['data']['Get'][self.index]] | ||||||||
else: | ||||||||
return [] | ||||||||
|
||||||||
except Exception as err: | ||||||||
print(f'Unexpected error {err=}, {type(err)=}') | ||||||||
return [] | ||||||||
|
||||||||
def get_stats(self): | ||||||||
result = self.client.query.aggregate(self.index) \ | ||||||||
.with_meta_count() \ | ||||||||
.do() | ||||||||
class_data = result['data']['Aggregate'][self.index] | ||||||||
|
||||||||
return class_data[0]['meta'] if class_data else {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import unittest | ||
from unittest import mock | ||
import sys | ||
import os | ||
|
||
from weaviate import Client | ||
from weaviate.util import get_valid_uuid | ||
from uuid import uuid4 | ||
|
||
sys.path.append(os.path.abspath('./scripts')) | ||
from config import Config | ||
from memory.weaviate import WeaviateMemory | ||
from memory.base import get_ada_embedding | ||
|
||
@mock.patch.dict(os.environ, { | ||
"WEAVIATE_HOST": "http://127.0.0.1", | ||
"WEAVIATE_PORT": "8080", | ||
"WEAVIATE_USERNAME": '', | ||
"WEAVIATE_PASSWORD": '', | ||
"MEMORY_INDEX": "AutogptTests" | ||
}) | ||
class TestWeaviateMemory(unittest.TestCase): | ||
""" | ||
In order to run these tests you will need a local instance of | ||
Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose | ||
for creating local instances using docker. | ||
""" | ||
def setUp(self): | ||
self.cfg = Config() | ||
|
||
self.client = Client('http://127.0.0.1:8080') | ||
|
||
try: | ||
self.client.schema.delete_class(self.cfg.memory_index) | ||
except: | ||
pass | ||
|
||
self.memory = WeaviateMemory(self.cfg) | ||
|
||
def test_add(self): | ||
doc = 'You are a Titan name Thanos and you are looking for the Infinity Stones' | ||
self.memory.add(doc) | ||
result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do() | ||
actual = result['data']['Get'][self.cfg.memory_index] | ||
|
||
self.assertEqual(len(actual), 1) | ||
self.assertEqual(actual[0]['raw_text'], doc) | ||
|
||
def test_get(self): | ||
doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos' | ||
|
||
with self.client.batch as batch: | ||
batch.add_data_object( | ||
uuid=get_valid_uuid(uuid4()), | ||
data_object={'raw_text': doc}, | ||
class_name=self.cfg.memory_index, | ||
vector=get_ada_embedding(doc) | ||
) | ||
|
||
batch.flush() | ||
|
||
actual = self.memory.get(doc) | ||
|
||
self.assertEqual(len(actual), 1) | ||
self.assertEqual(actual[0], doc) | ||
|
||
|
||
def test_get_stats(self): | ||
docs = [ | ||
'You are now about to count the number of docs in this index', | ||
'And then you about to find out if you can count correctly' | ||
] | ||
|
||
[self.memory.add(doc) for doc in docs] | ||
|
||
stats = self.memory.get_stats() | ||
|
||
self.assertTrue(stats) | ||
self.assertTrue('count' in stats) | ||
self.assertEqual(stats['count'], 2) | ||
|
||
|
||
def test_clear(self): | ||
docs = [ | ||
'Shame this is the last test for this class', | ||
'Testing is fun when someone else is doing it' | ||
] | ||
|
||
[self.memory.add(doc) for doc in docs] | ||
|
||
self.assertEqual(self.memory.get_stats()['count'], 2) | ||
|
||
self.memory.clear() | ||
|
||
self.assertEqual(self.memory.get_stats()['count'], 0) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Undo the relocation of the HUGGINGFACE line
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks! there were also some dupes I have now removed.