Skip to content

Adds metadata field to chat message history #357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion redisvl/extensions/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
CACHE_VECTOR_FIELD_NAME: str = "prompt_vector"
INSERTED_AT_FIELD_NAME: str = "inserted_at"
UPDATED_AT_FIELD_NAME: str = "updated_at"
METADATA_FIELD_NAME: str = "metadata"
METADATA_FIELD_NAME: str = "metadata" # also used in MessageHistory

# EmbeddingsCache
TEXT_FIELD_NAME: str = "text"
Expand Down
7 changes: 6 additions & 1 deletion redisvl/extensions/message_history/base_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from redisvl.extensions.constants import (
CONTENT_FIELD_NAME,
METADATA_FIELD_NAME,
ROLE_FIELD_NAME,
TOOL_FIELD_NAME,
)
from redisvl.extensions.message_history.schema import ChatMessage
from redisvl.utils.utils import create_ulid
from redisvl.utils.utils import create_ulid, deserialize


class BaseMessageHistory:
Expand Down Expand Up @@ -111,6 +112,10 @@ def _format_context(
}
if chat_message.tool_call_id is not None:
chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id
if chat_message.metadata is not None:
chat_message_dict[METADATA_FIELD_NAME] = deserialize(
chat_message.metadata
)

context.append(chat_message_dict) # type: ignore

Expand Down
8 changes: 7 additions & 1 deletion redisvl/extensions/message_history/message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from redisvl.extensions.constants import (
CONTENT_FIELD_NAME,
ID_FIELD_NAME,
METADATA_FIELD_NAME,
ROLE_FIELD_NAME,
SESSION_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
Expand All @@ -15,6 +16,7 @@
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Tag
from redisvl.utils.utils import serialize


class MessageHistory(BaseMessageHistory):
Expand Down Expand Up @@ -98,11 +100,13 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
CONTENT_FIELD_NAME,
TOOL_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
METADATA_FIELD_NAME,
]

query = FilterQuery(
filter_expression=self._default_session_filter,
return_fields=return_fields,
num_results=1000,
)
query.sort_by(TIMESTAMP_FIELD_NAME, asc=True)
messages = self._index.query(query)
Expand Down Expand Up @@ -144,6 +148,7 @@ def get_recent(
CONTENT_FIELD_NAME,
TOOL_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
METADATA_FIELD_NAME,
]

session_filter = (
Expand Down Expand Up @@ -210,7 +215,8 @@ def add_messages(

if TOOL_FIELD_NAME in message:
chat_message.tool_call_id = message[TOOL_FIELD_NAME]

if METADATA_FIELD_NAME in message:
chat_message.metadata = serialize(message[METADATA_FIELD_NAME])
chat_messages.append(chat_message.to_dict())

self._index.load(data=chat_messages, id_field=ID_FIELD_NAME)
Expand Down
10 changes: 8 additions & 2 deletions redisvl/extensions/message_history/schema.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from redisvl.extensions.constants import (
CONTENT_FIELD_NAME,
ID_FIELD_NAME,
MESSAGE_VECTOR_FIELD_NAME,
METADATA_FIELD_NAME,
ROLE_FIELD_NAME,
SESSION_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
TOOL_FIELD_NAME,
)
from redisvl.redis.utils import array_to_buffer
from redisvl.schema import IndexSchema
from redisvl.utils.utils import current_timestamp
from redisvl.utils.utils import current_timestamp, deserialize


class ChatMessage(BaseModel):
Expand All @@ -33,6 +34,8 @@ class ChatMessage(BaseModel):
"""An optional identifier for a tool call associated with the message."""
vector_field: Optional[List[float]] = Field(default=None)
"""The vector representation of the message content."""
metadata: Optional[str] = Field(default=None)
"""Optional additional data to store alongside the message"""
model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="before")
Expand All @@ -54,6 +57,7 @@ def to_dict(self, dtype: Optional[str] = None) -> Dict:
data[MESSAGE_VECTOR_FIELD_NAME] = array_to_buffer(
data[MESSAGE_VECTOR_FIELD_NAME], dtype # type: ignore[arg-type]
)

return data


Expand All @@ -70,6 +74,7 @@ def from_params(cls, name: str, prefix: str):
{"name": TOOL_FIELD_NAME, "type": "tag"},
{"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
{"name": SESSION_FIELD_NAME, "type": "tag"},
{"name": METADATA_FIELD_NAME, "type": "text"},
],
)

Expand All @@ -87,6 +92,7 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str):
{"name": TOOL_FIELD_NAME, "type": "tag"},
{"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
{"name": SESSION_FIELD_NAME, "type": "tag"},
{"name": METADATA_FIELD_NAME, "type": "text"},
{
"name": MESSAGE_VECTOR_FIELD_NAME,
"type": "vector",
Expand Down
12 changes: 9 additions & 3 deletions redisvl/extensions/message_history/semantic_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
CONTENT_FIELD_NAME,
ID_FIELD_NAME,
MESSAGE_VECTOR_FIELD_NAME,
METADATA_FIELD_NAME,
ROLE_FIELD_NAME,
SESSION_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
Expand All @@ -19,7 +20,7 @@
from redisvl.index import SearchIndex
from redisvl.query import FilterQuery, RangeQuery
from redisvl.query.filter import Tag
from redisvl.utils.utils import deprecated_argument, validate_vector_dims
from redisvl.utils.utils import deprecated_argument, serialize, validate_vector_dims
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer


Expand Down Expand Up @@ -149,8 +150,9 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]:
SESSION_FIELD_NAME,
ROLE_FIELD_NAME,
CONTENT_FIELD_NAME,
TOOL_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
TOOL_FIELD_NAME,
METADATA_FIELD_NAME,
]

query = FilterQuery(
Expand Down Expand Up @@ -214,6 +216,7 @@ def get_relevant(
CONTENT_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
TOOL_FIELD_NAME,
METADATA_FIELD_NAME,
]

session_filter = (
Expand Down Expand Up @@ -274,8 +277,9 @@ def get_recent(
SESSION_FIELD_NAME,
ROLE_FIELD_NAME,
CONTENT_FIELD_NAME,
TOOL_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
TOOL_FIELD_NAME,
METADATA_FIELD_NAME,
]

session_filter = (
Expand Down Expand Up @@ -355,6 +359,8 @@ def add_messages(

if TOOL_FIELD_NAME in message:
chat_message.tool_call_id = message[TOOL_FIELD_NAME]
if METADATA_FIELD_NAME in message:
chat_message.metadata = serialize(message[METADATA_FIELD_NAME])

chat_messages.append(chat_message.to_dict(dtype=self._vectorizer.dtype))

Expand Down
89 changes: 74 additions & 15 deletions tests/integration/test_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,15 @@ def test_standard_add_and_get(standard_history):
"role": "tool",
"content": "tool result 1",
"tool_call_id": "tool call one",
"metadata": {"tool call params": "abc 123"},
}
)
standard_history.add_message(
{
"role": "tool",
"content": "tool result 2",
"tool_call_id": "tool call two",
"metadata": {"tool call params": "abc 456"},
}
)
standard_history.add_message({"role": "user", "content": "third prompt"})
Expand All @@ -121,7 +123,12 @@ def test_standard_add_and_get(standard_history):
partial_context = standard_history.get_recent(top_k=3)
assert len(partial_context) == 3
assert partial_context == [
{"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"},
{
"role": "tool",
"content": "tool result 2",
"tool_call_id": "tool call two",
"metadata": {"tool call params": "abc 456"},
},
{"role": "user", "content": "third prompt"},
{"role": "llm", "content": "third response"},
]
Expand All @@ -133,8 +140,18 @@ def test_standard_add_and_get(standard_history):
{"role": "llm", "content": "first response"},
{"role": "user", "content": "second prompt"},
{"role": "llm", "content": "second response"},
{"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"},
{"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"},
{
"role": "tool",
"content": "tool result 1",
"tool_call_id": "tool call one",
"metadata": {"tool call params": "abc 123"},
},
{
"role": "tool",
"content": "tool result 2",
"tool_call_id": "tool call two",
"metadata": {"tool call params": "abc 456"},
},
{"role": "user", "content": "third prompt"},
{"role": "llm", "content": "third response"},
]
Expand All @@ -160,7 +177,11 @@ def test_standard_add_messages(standard_history):
standard_history.add_messages(
[
{"role": "user", "content": "first prompt"},
{"role": "llm", "content": "first response"},
{
"role": "llm",
"content": "first response",
"metadata": {"llm provider": "openai"},
},
{"role": "user", "content": "second prompt"},
{"role": "llm", "content": "second response"},
{
Expand All @@ -182,7 +203,11 @@ def test_standard_add_messages(standard_history):
assert len(full_context) == 8
assert full_context == [
{"role": "user", "content": "first prompt"},
{"role": "llm", "content": "first response"},
{
"role": "llm",
"content": "first response",
"metadata": {"llm provider": "openai"},
},
{"role": "user", "content": "second prompt"},
{"role": "llm", "content": "second response"},
{"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"},
Expand All @@ -198,17 +223,21 @@ def test_standard_messages_property(standard_history):
{"role": "user", "content": "first prompt"},
{"role": "llm", "content": "first response"},
{"role": "user", "content": "second prompt"},
{"role": "llm", "content": "second response"},
{"role": "user", "content": "third prompt"},
{
"role": "llm",
"content": "second response",
"metadata": {"params": "abc"},
},
{"role": "user", "content": "third prompt", "metadata": 42},
]
)

assert standard_history.messages == [
{"role": "user", "content": "first prompt"},
{"role": "llm", "content": "first response"},
{"role": "user", "content": "second prompt"},
{"role": "llm", "content": "second response"},
{"role": "user", "content": "third prompt"},
{"role": "llm", "content": "second response", "metadata": {"params": "abc"}},
{"role": "user", "content": "third prompt", "metadata": 42},
]


Expand Down Expand Up @@ -357,7 +386,14 @@ def test_semantic_store_and_get_recent(semantic_history):
semantic_history.add_message(
{"role": "tool", "content": "tool result", "tool_call_id": "tool id"}
)
# test default context history size
semantic_history.add_message(
{
"role": "tool",
"content": "tool result",
"tool_call_id": "tool id",
"metadata": "return value from tool",
}
) # test default context history size
default_context = semantic_history.get_recent()
assert len(default_context) == 5 # 5 is default

Expand All @@ -367,10 +403,10 @@ def test_semantic_store_and_get_recent(semantic_history):

# test larger context history returns full history
too_large_context = semantic_history.get_recent(top_k=100)
assert len(too_large_context) == 9
assert len(too_large_context) == 10

# test that order is maintained
full_context = semantic_history.get_recent(top_k=9)
full_context = semantic_history.get_recent(top_k=10)
assert full_context == [
{"role": "user", "content": "first prompt"},
{"role": "llm", "content": "first response"},
Expand All @@ -381,15 +417,26 @@ def test_semantic_store_and_get_recent(semantic_history):
{"role": "user", "content": "fourth prompt"},
{"role": "llm", "content": "fourth response"},
{"role": "tool", "content": "tool result", "tool_call_id": "tool id"},
{
"role": "tool",
"content": "tool result",
"tool_call_id": "tool id",
"metadata": "return value from tool",
},
]

# test that more recent entries are returned
context = semantic_history.get_recent(top_k=4)
assert context == [
{"role": "llm", "content": "third response"},
{"role": "user", "content": "fourth prompt"},
{"role": "llm", "content": "fourth response"},
{"role": "tool", "content": "tool result", "tool_call_id": "tool id"},
{
"role": "tool",
"content": "tool result",
"tool_call_id": "tool id",
"metadata": "return value from tool",
},
]

# test no entries are returned and no error is raised if top_k == 0
Expand Down Expand Up @@ -422,11 +469,13 @@ def test_semantic_messages_property(semantic_history):
"role": "tool",
"content": "tool result 1",
"tool_call_id": "tool call one",
"metadata": 42,
},
{
"role": "tool",
"content": "tool result 2",
"tool_call_id": "tool call two",
"metadata": [1, 2, 3],
},
{"role": "user", "content": "second prompt"},
{"role": "llm", "content": "second response"},
Expand All @@ -437,8 +486,18 @@ def test_semantic_messages_property(semantic_history):
assert semantic_history.messages == [
{"role": "user", "content": "first prompt"},
{"role": "llm", "content": "first response"},
{"role": "tool", "content": "tool result 1", "tool_call_id": "tool call one"},
{"role": "tool", "content": "tool result 2", "tool_call_id": "tool call two"},
{
"role": "tool",
"content": "tool result 1",
"tool_call_id": "tool call one",
"metadata": 42,
},
{
"role": "tool",
"content": "tool result 2",
"tool_call_id": "tool call two",
"metadata": [1, 2, 3],
},
{"role": "user", "content": "second prompt"},
{"role": "llm", "content": "second response"},
{"role": "user", "content": "third prompt"},
Expand Down
Loading
Loading