Feature/weaviate memory #424
Changes to the setup documentation (README):
@@ -140,7 +140,13 @@ export CUSTOM_SEARCH_ENGINE_ID="YOUR_CUSTOM_SEARCH_ENGINE_ID"

## Vector-based memory provider
Auto-GPT supports two providers for vector-based memory: [Pinecone](https://www.pinecone.io/) and [Weaviate](https://weaviate.io/). To select the provider to use, specify the following in your `.env`:

```
MEMORY_PROVIDER="pinecone" # change to "weaviate" to use weaviate as the memory provider
```

### 🌲 Pinecone API Key Setup

Pinecone enables a vector-based memory, so a vast memory can be stored and only the relevant memories are loaded for the agent at any given time.
@@ -149,7 +155,7 @@ are loaded for the agent at any given time.

2. Choose the `Starter` plan to avoid being charged.
3. Find your API key and region under the default project in the left sidebar.

#### Setting up environment variables
For Windows Users:
```
setx PINECONE_API_KEY "YOUR_PINECONE_API_KEY"
```
@@ -165,6 +171,22 @@ export PINECONE_ENV="Your pinecone region" # something like: us-east4-gcp

Or you can set them in the `.env` file.
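For example, the corresponding `.env` entries might look like the following (placeholder values shown; use your own key and region):

```
PINECONE_API_KEY="YOUR_PINECONE_API_KEY"
PINECONE_ENV="us-east4-gcp" # your Pinecone region
```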
### Weaviate Setup

[Weaviate](https://weaviate.io/) is an open-source vector database. It allows you to store data objects and vector embeddings from ML models and scales seamlessly to billions of data objects. [An instance of Weaviate can be created locally (using Docker), on Kubernetes, or using Weaviate Cloud Services](https://weaviate.io/developers/weaviate/quickstart).

#### Setting up environment variables

Review comment: Followed the instructions and ... it did not work 😂 Because it is missing the weaviate client. So, perhaps edit to say something like "First, run […]

Reply: @hsm207 HA! I had assumed […]

In your `.env` file set the following:

```
WEAVIATE_HOST="http://127.0.0.1" # the URL of the running Weaviate instance
WEAVIATE_PORT="8080"
WEAVIATE_USERNAME="your username"
WEAVIATE_PASSWORD="your password"
WEAVIATE_INDEX="Autogpt" # name of the index to create for the application
```
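Before starting Auto-GPT it can be worth confirming that the configured Weaviate instance is actually reachable. A minimal sketch (not part of this diff), assuming the v3 `weaviate-client` package is installed and the host/port values above:

```python
import weaviate

# URL assembled from WEAVIATE_HOST and WEAVIATE_PORT as configured above.
client = weaviate.Client("http://127.0.0.1:8080")

# is_ready() returns True once the Weaviate REST endpoint responds.
print("Weaviate ready:", client.is_ready())
```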
## View Memory Usage

1. View memory usage by using the `--debug` flag :)
New file: memory provider factory (@@ -0,0 +1,13 @@)

```python
from providers.pinecone import PineconeMemory
from providers.weaviate import WeaviateMemory


class MemoryFactory:
    @staticmethod
    def get_memory(mem_type):
        if mem_type == 'pinecone':
            return PineconeMemory()

        if mem_type == 'weaviate':
            return WeaviateMemory()

        raise ValueError('Unknown memory provider')
```
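Not part of the diff, but a sketch of how the factory might be wired to the `MEMORY_PROVIDER` setting from the README changes; the `cfg.memory_provider` attribute and the `providers.factory` module path are assumptions for illustration:

```python
from config import Config
from providers.factory import MemoryFactory  # assumed location of the factory above

cfg = Config()

# MEMORY_PROVIDER from .env is assumed to surface as cfg.memory_provider (hypothetical name).
memory = MemoryFactory.get_memory(cfg.memory_provider)
memory.add("Auto-GPT started a new task.")
```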
New file: memory provider base class (@@ -0,0 +1,26 @@)

```python
from config import Singleton
import openai


def get_ada_embedding(text):
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]


def get_text_from_embedding(embedding):
    return openai.Embedding.retrieve(embedding, model="text-embedding-ada-002")["data"][0]["text"]


class Memory(metaclass=Singleton):
    def add(self, data):
        raise NotImplementedError()

    def get(self, data):
        raise NotImplementedError()

    def clear(self):
        raise NotImplementedError()

    def get_relevant(self, data, num_relevant=5):
        raise NotImplementedError()

    def get_stats(self):
        raise NotImplementedError()
```
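To make the contract of `Memory` concrete, here is a minimal, hypothetical in-process provider (not part of the PR) that stores `(text, embedding)` pairs in a list and ranks them by cosine similarity:

```python
import numpy as np

from providers.memory import Memory, get_ada_embedding  # same imports the Weaviate provider uses


class LocalMemory(Memory):
    """Hypothetical provider: keeps (text, embedding) pairs in a Python list."""

    def __init__(self):
        self.items = []  # list of (text, np.ndarray) tuples

    def add(self, data):
        # get_ada_embedding requires a configured OpenAI API key.
        self.items.append((data, np.array(get_ada_embedding(data))))
        return f"Stored: {data}"

    def get(self, data):
        return self.get_relevant(data, 1)

    def clear(self):
        self.items = []
        return "Cleared"

    def get_relevant(self, data, num_relevant=5):
        if not self.items:
            return []
        query = np.array(get_ada_embedding(data))
        # Cosine similarity between the query embedding and every stored embedding.
        scores = [
            float(np.dot(query, emb) / (np.linalg.norm(query) * np.linalg.norm(emb)))
            for _, emb in self.items
        ]
        ranked = sorted(zip(scores, (text for text, _ in self.items)), reverse=True)
        return [text for _, text in ranked[:num_relevant]]

    def get_stats(self):
        return {"count": len(self.items)}
```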
New file: Weaviate memory provider (@@ -0,0 +1,100 @@)

```python
from config import Config
from providers.memory import Memory, get_ada_embedding
from weaviate import Client
import weaviate

from weaviate.util import generate_uuid5

cfg = Config()

SCHEMA = {
    "class": cfg.weaviate_index,
    "properties": [
        {
            "name": "raw_text",
            "dataType": ["text"],
            "description": "original text for the embedding"
        }
    ],
}


class WeaviateMemory(Memory):

    def __init__(self):
        auth_credentials = self._build_auth_credentials()

        url = f'{cfg.weaviate_host}:{cfg.weaviate_port}'

        self.client = Client(url, auth_client_secret=auth_credentials)

        self._create_schema()

    def _create_schema(self):
        if not self.client.schema.contains(SCHEMA):
            self.client.schema.create_class(SCHEMA)

    @staticmethod
    def _build_auth_credentials():
        # Username/password auth is optional; anonymous access is used otherwise.
        if cfg.weaviate_username and cfg.weaviate_password:
            return weaviate.auth.AuthClientPassword(cfg.weaviate_username, cfg.weaviate_password)
        else:
            return None

    def add(self, data):
        vector = get_ada_embedding(data)

        # Deterministic UUID derived from the content and the index name, so
        # re-inserting the same text does not create duplicate objects.
        doc_uuid = generate_uuid5(data, cfg.weaviate_index)
        # Only schema properties belong in the data object; the class is passed
        # separately via class_name below.
        data_object = {
            'raw_text': data
        }

        with self.client.batch as batch:
            batch.add_data_object(
                uuid=doc_uuid,
                data_object=data_object,
                class_name=cfg.weaviate_index,
                vector=vector
            )

            batch.flush()

        return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"

    def get(self, data):
        return self.get_relevant(data, 1)

    def clear(self):
        self.client.schema.delete_all()

        # weaviate does not yet have a neat way to just remove the items in an index
        # without removing the entire schema, therefore we need to re-create it
        # after a call to delete_all
        self._create_schema()

        return 'Obliterated'

    def get_relevant(self, data, num_relevant=5):
        query_embedding = get_ada_embedding(data)
        try:
            results = self.client.query.get(cfg.weaviate_index, ['raw_text']) \
                .with_near_vector({'vector': query_embedding, 'certainty': 0.7}) \
                .with_limit(num_relevant) \
                .do()

            print(results)

            if len(results['data']['Get'][cfg.weaviate_index]) > 0:
                return [str(item['raw_text']) for item in results['data']['Get'][cfg.weaviate_index]]
            else:
                return []

        except Exception as err:
            print(f'Unexpected error {err=}, {type(err)=}')
            return []

    def get_stats(self):
        # The weaviate client has no index_stats helper; an aggregate query with
        # a meta count is the closest equivalent.
        return self.client.query.aggregate(cfg.weaviate_index) \
            .with_meta_count() \
            .do()
```
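A usage sketch for the provider above (assumes a reachable Weaviate instance, the `.env` values from the README changes, and an OpenAI key for the embeddings; the stored strings are illustrative):

```python
from providers.weaviate import WeaviateMemory  # module path assumed from the factory's import

memory = WeaviateMemory()

# Store a couple of observations; each is embedded with text-embedding-ada-002
# and written to the configured index.
memory.add("The user asked Auto-GPT to summarise a research paper.")
memory.add("The summary was saved to summary.txt.")

# Retrieve up to 5 semantically similar memories for a new query.
print(memory.get_relevant("What did the user ask for?", num_relevant=5))

# Object count for the index, via the aggregate meta count.
print(memory.get_stats())
```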
Review comment: @cs0lar should also mention embedded weaviate here.
Reply: good spot, thanks! This is now done.