Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Memory validation fix + core_memory_replace runaway content repeating fix #1616

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion memgpt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.24"
__version__ = "0.3.25"

from memgpt.client.admin import Admin
from memgpt.client.client import create_client
2 changes: 2 additions & 0 deletions memgpt/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,8 +662,10 @@ def run(
system_prompt = system if system else None
if human_obj is None:
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
sys.exit(1)
if persona_obj is None:
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
sys.exit(1)

memory = ChatMemory(human=human_obj.text, persona=persona_obj.text, limit=core_memory_limit)
metadata = {"human": human_obj.name, "persona": persona_obj.name}
Expand Down
6 changes: 3 additions & 3 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
from memgpt.server.server import SyncServer
from memgpt.utils import get_human_text
from memgpt.utils import get_human_text, get_persona_text


def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
Expand Down Expand Up @@ -259,7 +259,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# system prompt (can be templated)
system_prompt: Optional[str] = None,
# tools
Expand Down Expand Up @@ -729,7 +729,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# system prompt (can be templated)
system_prompt: Optional[str] = None,
# tools
Expand Down
26 changes: 20 additions & 6 deletions memgpt/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

Expand All @@ -19,6 +20,7 @@
)


# always run validation
class MemoryModule(BaseModel):
"""Base class for memory modules"""

Expand All @@ -35,8 +37,9 @@ def __setattr__(self, name, value):

super().__setattr__(name, value)

@validator("value", always=True)
@validator("value", always=True, check_fields=False)
def check_value_length(cls, v, values):
# TODO: this doesn't run all the time, should fix
if v is not None:
# Fetching the limit from the values dictionary
limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set
Expand All @@ -50,10 +53,9 @@ def check_value_length(cls, v, values):
raise ValueError("Value must be either a string or a list of strings.")

if length > limit:
error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
# TODO: add archival memory error?
raise ValueError(error_msg)
return v
raise ValueError(f"Value exceeds {limit} character limit (requested {length}).")

return v

def __len__(self):
return len(str(self))
Expand All @@ -73,10 +75,14 @@ def __init__(self):
self.memory = {}

@classmethod
def load(cls, state: dict):
def load(cls, state: dict, catch_overflow: bool = True):
"""Load memory from dictionary object"""
obj = cls()
for key, value in state.items():
# TODO: will cause an error for lists
if catch_overflow and len(value["value"]) >= value["limit"]:
warnings.warn(f"Loaded section {key} exceeds character limit {value['limit']} - increasing specified memory limit.")
value["limit"] = len(value["value"])
obj.memory[key] = MemoryModule(**value)
return obj

Expand All @@ -95,6 +101,14 @@ def to_dict(self):
class ChatMemory(BaseMemory):

def __init__(self, persona: str, human: str, limit: int = 2000):
# TODO: clip if needed
# if persona and len(persona) > limit:
# warnings.warn(f"Persona exceeds {limit} character limit (requested {len(persona)}).")
# persona = persona[:limit]

# if human and len(human) > limit:
# warnings.warn(f"Human exceeds {limit} character limit (requested {len(human)}).")
# human = human[:limit]
self.memory = {
"persona": MemoryModule(name="persona", value=persona, limit=limit),
"human": MemoryModule(name="human", value=human, limit=limit),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pymemgpt"
version = "0.3.24"
version = "0.3.25"
packages = [
{include = "memgpt"}
]
Expand Down
39 changes: 39 additions & 0 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,23 @@ def test_create_chat_memory():
assert chat_memory.memory["human"].value == "User"


def test_overflow_chat_memory():
"""Test overflowing an instance of ChatMemory"""
chat_memory = ChatMemory(persona="Chat Agent", human="User")
assert chat_memory.memory["persona"].value == "Chat Agent"
assert chat_memory.memory["human"].value == "User"

# try overflowing via core_memory_append
with pytest.raises(ValueError):
persona_limit = chat_memory.memory["persona"].limit
chat_memory.core_memory_append(name="persona", content="x" * (persona_limit + 1))

# try overflowing via core_memory_replace
with pytest.raises(ValueError):
persona_limit = chat_memory.memory["persona"].limit
chat_memory.core_memory_replace(name="persona", old_content="Chat Agent", new_content="x" * (persona_limit + 1))


def test_dump_memory_as_json(sample_memory):
"""Test dumping ChatMemory as JSON compatible dictionary"""
memory_dict = sample_memory.to_dict()
Expand Down Expand Up @@ -63,3 +80,25 @@ def test_memory_limit_validation(sample_memory):

with pytest.raises(ValueError):
sample_memory.memory["persona"].value = "x" * 3000


def test_corrupted_memory_limit(sample_memory):
"""Test what happens when a memory is stored with a value over the limit

See: https://github.com/cpacker/MemGPT/issues/1567
"""
with pytest.raises(ValueError):
ChatMemory(persona="x" * 3000, human="y" * 3000)

memory_dict = sample_memory.to_dict()
assert memory_dict["persona"]["limit"] == 2000, memory_dict

# overflow the value
memory_dict["persona"]["value"] = "x" * 2500

# by default, this should throw a value error
with pytest.raises(ValueError):
BaseMemory.load(memory_dict, catch_overflow=False)

# if we have overflow protection on, this shouldn't raise a value error
BaseMemory.load(memory_dict)